diff --git a/search/uct.py b/search/uct.py
index 34d7db8..08f3699 100644
--- a/search/uct.py
+++ b/search/uct.py
@@ -67,9 +67,24 @@ def get_best_move(root):
     score = int(round(cp(node.Q()),0))
     return bestmove, node, score
 
+def getBest(node):
+    """Return the (move, child) pair of node's most-visited child; ties are broken by the higher Q()."""
+    return max(node.children.items(), key=lambda item: (item[1].number_visits, item[1].Q()))
+
 def send_info(send, bestmove, count, delta, score):
     if send != None:
-        send("info depth 1 seldepth 1 score cp {} nodes {} nps {} pv {}".format(score, count, int(round(count/delta, 0)), bestmove))
+        # bestmove is the (pv_string, depth) pair produced by pv(): [1] is the depth, [0] the move list
+        send("info depth {} seldepth 1 score cp {} nodes {} nps {} pv {}".format(bestmove[1], score, count, int(round(count/delta, 0)), bestmove[0]))
+
+def pv(root):
+    """Return (pv, depth): the most-visited line below root as a space-separated move string, and its length."""
+    moves = []
+    current = root
+    # also stop at childless nodes: getBest() would call max() on an empty dict there and raise ValueError
+    while current.is_expanded and current.children:
+        bestmove, current = getBest(current)
+        moves.append(bestmove)
+    return " ".join(moves), len(moves)
 
 def UCT_search(board, num_reads, net=None, C=1.0, verbose=False, max_time=None, tree=None, send=None):
     if max_time == None:
@@ -93,7 +108,7 @@ def UCT_search(board, num_reads, net=None, C=1.0, verbose=False, max_time=None,
         if (delta - delta_last > 5):
             delta_last = delta
             bestmove, node, score = get_best_move(root)
-            send_info(send, bestmove, count, delta, score)
+            send_info(send, pv(root), count, delta, score)
 
         if (time != None) and (delta > max_time):
             break
@@ -102,7 +117,8 @@ def UCT_search(board, num_reads, net=None, C=1.0, verbose=False, max_time=None,
     if send != None:
         for nd in sorted(root.children.items(), key= lambda item: item[1].number_visits):
             send("info string {} {} \t(P: {}%) \t(Q: {})".format(nd[1].move, nd[1].number_visits, round(nd[1].prior*100,2), round(nd[1].Q(), 5)))
-        send("info depth 1 seldepth 1 score cp {} nodes {} nps {} pv {}".format(score, count, int(round(count/delta, 0)), bestmove))
+        pa, d = pv(root)
+        send("info depth {} seldepth 1 score cp {} nodes {} nps {} pv {}".format(d, score, count, int(round(count/delta, 0)), pa))
 
     # if we have a bad score, go for a draw
     return bestmove, score