Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions search/uct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
FPU_ROOT = 0.0

class UCTNode():
def __init__(self, board=None, parent=None, move=None, prior=0):
def __init__(self, board=None, parent=None, prior=0):
self.board = board
self.move = move
self.is_expanded = False
self.parent = parent # Optional[UCTNode]
self.children = OrderedDict() # Dict[move, UCTNode]
Expand All @@ -32,14 +31,19 @@ def U(self): # returns float
def best_child(self, C):
    """Return the child node with the highest selection score.

    The score of a child is ``child.Q() + C * child.U()``.

    Args:
        C: Multiplier applied to each child's ``U()`` term.

    Returns:
        The child (a value of ``self.children``) with the largest score.

    Raises:
        ValueError: If ``self.children`` is empty (raised by ``max``).
    """
    return max(self.children.values(),
               key=lambda node: node.Q() + C * node.U())

def best_move_and_child(self, C):
    """Return the ``(move, child)`` pair with the highest selection score.

    The score of a child is ``child.Q() + C * child.U()``; the move is the
    key under which that child is stored in ``self.children``.

    Args:
        C: Multiplier applied to each child's ``U()`` term.

    Returns:
        A ``(move, child)`` tuple taken from ``self.children.items()``.

    Raises:
        ValueError: If ``self.children`` is empty (raised by ``max``).
    """
    # max() hands the key function ONE argument -- the (move, child) tuple --
    # so the key must index into it rather than declare two parameters.
    return max(self.children.items(),
               key=lambda item: item[1].Q() + C * item[1].U())

def select_leaf(self, C):
    """Walk down the tree to a leaf and make sure it has a board.

    Starting at this node, repeatedly step to the best ``(move, child)``
    pair while the current node is expanded and has children. Nodes no
    longer store the move that produced them, so the move is tracked
    during the descent (it is the key in the parent's ``children`` dict).

    Args:
        C: Passed through to ``best_move_and_child`` at every step.

    Returns:
        The node at which the descent stopped. If that node had no board,
        one is created by copying the parent's board and pushing the move
        that led to the node.
    """
    current = self
    move = None  # move that led to ``current``; stays None if we never descend
    while current.is_expanded and current.children:
        move, current = current.best_move_and_child(C)
    if not current.board:
        # Boards are built lazily: only the leaf actually reached pays for
        # the copy of its parent's position.
        current.board = current.parent.board.copy()
        current.board.push_uci(move)
    return current

def expand(self, child_priors):
Expand All @@ -48,7 +52,7 @@ def expand(self, child_priors):
self.add_child(move, prior)

def add_child(self, move, prior):
    """Create a child node for ``move`` and store it in ``self.children``.

    The child is constructed with this node as its parent and the given
    prior. The move is not passed to the child: it is kept only as the key
    in ``self.children``.

    Args:
        move: Key under which the new child is stored.
        prior: Prior value handed to the child's constructor.
    """
    self.children[move] = UCTNode(parent=self, prior=prior)

def backup(self, value_estimate: float):
current = self
Expand Down Expand Up @@ -87,8 +91,8 @@ def UCT_search(board, num_reads, net=None, C=1.0, verbose=False, max_time=None,
bestmove, node = max(root.children.items(), key=lambda item: (item[1].number_visits, item[1].Q()))
score = int(round(cp(node.Q()),0))
if send != None:
for nd in sorted(root.children.items(), key= lambda item: item[1].number_visits):
send("info string {} {} \t(P: {}%) \t(Q: {})".format(nd[1].move, nd[1].number_visits, round(nd[1].prior*100,2), round(nd[1].Q(), 5)))
for nd in sorted(root.children.items(), key= lambda move, item: item[1].number_visits):
send("info string {} {} \t(P: {}%) \t(Q: {})".format(nd[0], nd[1].number_visits, round(nd[1].prior*100,2), round(nd[1].Q(), 5)))
send("info depth 1 seldepth 1 score cp {} nodes {} nps {} pv {}".format(score, count, int(round(count/delta, 0)), bestmove))

# if we have a bad score, go for a draw
Expand Down