diff --git a/search/uct.py b/search/uct.py
index 2204db4..52a17d3 100644
--- a/search/uct.py
+++ b/search/uct.py
@@ -9,9 +9,8 @@ FPU_ROOT = 0.0
 
 
 class UCTNode():
-    def __init__(self, board=None, parent=None, move=None, prior=0):
+    def __init__(self, board=None, parent=None, prior=0):
         self.board = board
-        self.move = move
         self.is_expanded = False
         self.parent = parent # Optional[UCTNode]
         self.children = OrderedDict() # Dict[move, UCTNode]
@@ -32,14 +31,20 @@ def U(self): # returns float
     def best_child(self, C):
         return max(self.children.values(),
                    key=lambda node: node.Q() + C*node.U())
+
+    def best_move_and_child(self, C):
+        """Return the (move, child) pair whose child maximizes Q() + C*U()."""
+        return max(self.children.items(),
+                   key=lambda item: item[1].Q() + C*item[1].U())
 
     def select_leaf(self, C):
         current = self
+        move = None
         while current.is_expanded and current.children:
-            current = current.best_child(C)
+            move, current = current.best_move_and_child(C)
         if not current.board:
             current.board = current.parent.board.copy()
-            current.board.push_uci(current.move)
+            current.board.push_uci(move)
         return current
 
     def expand(self, child_priors):
@@ -48,7 +53,7 @@ def expand(self, child_priors):
             self.add_child(move, prior)
 
     def add_child(self, move, prior):
-        self.children[move] = UCTNode(parent=self, move=move, prior=prior)
+        self.children[move] = UCTNode(parent=self, prior=prior)
 
     def backup(self, value_estimate: float):
         current = self
@@ -87,8 +92,8 @@ def UCT_search(board, num_reads, net=None, C=1.0, verbose=False, max_time=None,
     bestmove, node = max(root.children.items(), key=lambda item: (item[1].number_visits, item[1].Q()))
     score = int(round(cp(node.Q()),0))
     if send != None:
         for nd in sorted(root.children.items(), key= lambda item: item[1].number_visits):
-            send("info string {} {} \t(P: {}%) \t(Q: {})".format(nd[1].move, nd[1].number_visits, round(nd[1].prior*100,2), round(nd[1].Q(), 5)))
+            send("info string {} {} \t(P: {}%) \t(Q: {})".format(nd[0], nd[1].number_visits, round(nd[1].prior*100,2), round(nd[1].Q(), 5)))
         send("info depth 1 seldepth 1 score cp {} nodes {} nps {} pv {}".format(score, count, int(round(count/delta, 0)), bestmove))
 
     # if we have a bad score, go for a draw