diff --git a/src/search/classic/node.cc b/src/search/classic/node.cc index d473d16a90..ea63d9246d 100644 --- a/src/search/classic/node.cc +++ b/src/search/classic/node.cc @@ -178,7 +178,8 @@ float Edge::GetP() const { std::string Edge::DebugString() const { std::ostringstream oss; oss << "Move: " << move_.ToString(true) << " p_: " << p_ - << " GetP: " << GetP(); + << " GetP: " << GetP() << " n0: " << virtual_n_ + << " w0: " << virtual_wl_ << " d0: " << virtual_d_; return oss.str(); } diff --git a/src/search/classic/node.h b/src/search/classic/node.h index 8a4e598fdb..aa957c1cce 100644 --- a/src/search/classic/node.h +++ b/src/search/classic/node.h @@ -96,6 +96,15 @@ class Edge { float GetP() const; void SetP(float val); + float GetVirtualN() const { return virtual_n_; } + float GetVirtualWL() const { return virtual_wl_; } + float GetVirtualD() const { return virtual_d_; } + void SetVirtualStats(float n0, float wl0, float d0) { + virtual_n_ = n0; + virtual_wl_ = wl0; + virtual_d_ = d0; + } + // Debug information about the edge. std::string DebugString() const; @@ -108,6 +117,11 @@ class Edge { // Probability that this move will be made, from the policy head of the neural // network; compressed to a 16 bit format (5 bits exp, 11 bits significand). uint16_t p_ = 0; + + // Virtual statistics applied once during expansion. + float virtual_n_ = 0.0f; + float virtual_wl_ = 0.0f; + float virtual_d_ = 0.0f; friend class Node; }; @@ -373,13 +387,56 @@ class EdgeAndNode { // Proxy functions for easier access to node/edge. float GetQ(float default_q, float draw_score) const { - return (node_ && node_->GetN() > 0) ? node_->GetQ(draw_score) : default_q; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f; + const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + const float wl_sum = node_->GetWL() * real_n + virtual_wl; + const float d_sum = node_->GetD() * real_n + virtual_d; + const float avg_wl = wl_sum / total_n; + const float avg_d = d_sum / total_n; + return avg_wl + draw_score * avg_d; + } + } + if (virtual_n > 0.0f) { + const float avg_wl = virtual_wl / virtual_n; + const float avg_d = virtual_d / virtual_n; + return avg_wl + draw_score * avg_d; + } + return default_q; } float GetWL(float default_wl) const { - return (node_ && node_->GetN() > 0) ? node_->GetWL() : default_wl; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + return (node_->GetWL() * real_n + virtual_wl) / total_n; + } + } + if (virtual_n > 0.0f) { + return virtual_wl / virtual_n; + } + return default_wl; } float GetD(float default_d) const { - return (node_ && node_->GetN() > 0) ? node_->GetD() : default_d; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + return (node_->GetD() * real_n + virtual_d) / total_n; + } + } + if (virtual_n > 0.0f) { + return virtual_d / virtual_n; + } + return default_d; } float GetM(float default_m) const { return (node_ && node_->GetN() > 0) ? node_->GetM() : default_m; @@ -406,9 +463,14 @@ class EdgeAndNode { // Returns U = numerator * p / N. // Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])). float GetU(float numerator) const { - return numerator * GetP() / (1 + GetNStarted()); + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + return numerator * GetP() / (1.0f + GetNStarted() + virtual_n); } + float GetVirtualN() const { return edge_ ? edge_->GetVirtualN() : 0.0f; } + float GetVirtualWL() const { return edge_ ? edge_->GetVirtualWL() : 0.0f; } + float GetVirtualD() const { return edge_ ? edge_->GetVirtualD() : 0.0f; } + std::string DebugString() const; protected: diff --git a/src/search/classic/params.cc b/src/search/classic/params.cc index e61b0f9c88..e59906a521 100644 --- a/src/search/classic/params.cc +++ b/src/search/classic/params.cc @@ -208,6 +208,9 @@ const OptionId BaseSearchParams::kCpuctFactorAtRootId{ .uci_option = "CPuctFactorAtRoot", .help_text = "Multiplier for the cpuct growth formula at root.", .visibility = OptionId::kProOnly}}; +const OptionId BaseSearchParams::kNodePriorAlphaId{ + "node-prior-alpha", "NodePriorAlpha", + "Strength multiplier for virtual node priors (K = alpha * legal moves)."}; // Remove this option after 0.25 has been made mandatory in training and the // training server stops sending it. const OptionId BaseSearchParams::kRootHasOwnCpuctParamsId{ @@ -550,6 +553,7 @@ void BaseSearchParams::Populate(OptionsParser* options) { options->Add(kCpuctBaseAtRootId, 1.0f, 1000000000.0f) = 38739.0f; options->Add(kCpuctFactorId, 0.0f, 1000.0f) = 3.894f; options->Add(kCpuctFactorAtRootId, 0.0f, 1000.0f) = 3.894f; + options->Add(kNodePriorAlphaId, 0.0f, 1000.0f) = 0.0f; options->Add(kRootHasOwnCpuctParamsId) = false; options->Add(kTwoFoldDrawsId) = true; options->Add(kTemperatureId, 0.0f, 100.0f) = 0.0f; @@ -653,6 +657,7 @@ BaseSearchParams::BaseSearchParams(const OptionsDict& options) kCpuctFactorAtRoot(options.Get( options.Get(kRootHasOwnCpuctParamsId) ? kCpuctFactorAtRootId : kCpuctFactorId)), + kNodePriorAlpha(options.Get(kNodePriorAlphaId)), kTwoFoldDraws(options.Get(kTwoFoldDrawsId)), kNoiseEpsilon(options.Get(kNoiseEpsilonId)), kNoiseAlpha(options.Get(kNoiseAlphaId)), diff --git a/src/search/classic/params.h b/src/search/classic/params.h index d84dbad5d8..b7faae5553 100644 --- a/src/search/classic/params.h +++ b/src/search/classic/params.h @@ -63,6 +63,7 @@ class BaseSearchParams { float GetCpuctFactor(bool at_root) const { return at_root ? kCpuctFactorAtRoot : kCpuctFactor; } + float GetNodePriorAlpha() const { return kNodePriorAlpha; } bool GetTwoFoldDraws() const { return kTwoFoldDraws; } float GetTemperature() const { return options_.Get(kTemperatureId); } float GetTemperatureVisitOffset() const { @@ -171,6 +172,7 @@ class BaseSearchParams { static const OptionId kCpuctBaseAtRootId; static const OptionId kCpuctFactorId; static const OptionId kCpuctFactorAtRootId; + static const OptionId kNodePriorAlphaId; static const OptionId kRootHasOwnCpuctParamsId; static const OptionId kTwoFoldDrawsId; static const OptionId kTemperatureId; @@ -248,6 +250,7 @@ class BaseSearchParams { const float kCpuctBaseAtRoot; const float kCpuctFactor; const float kCpuctFactorAtRoot; + const float kNodePriorAlpha; const bool kTwoFoldDraws; const float kNoiseEpsilon; const float kNoiseAlpha; diff --git a/src/search/classic/search.cc b/src/search/classic/search.cc index 101d62e941..4ab8bfd495 100644 --- a/src/search/classic/search.cc +++ b/src/search/classic/search.cc @@ -1696,7 +1696,8 @@ void SearchWorker::PickNodesToExtendTask( for (Node* child : node->VisitedNodes()) { int index = child->Index(); visited_pol += current_pol[index]; - float q = child->GetQ(draw_score); + EdgeAndNode edge_wrapper(node->GetEdgeToNode(child), child); + float q = edge_wrapper.GetQ(child->GetQ(draw_score), draw_score); current_util[index] = q + m_evaluator.GetMUtility(child, q); } const float fpu = @@ -1730,10 +1731,13 @@ void SearchWorker::PickNodesToExtendTask( current_nstarted[idx] = cur_iters[idx].GetNStarted(); } int nstarted = current_nstarted[idx]; - const float util = current_util[idx]; + float util = current_util[idx]; if (idx > cache_filled_idx) { - current_score[idx] = - current_pol[idx] * puct_mult / (1 + nstarted) + util; + const float virtual_n = cur_iters[idx].GetVirtualN(); + current_score[idx] = current_pol[idx] * puct_mult / + (1.0f + nstarted + virtual_n) + + current_util[idx]; + util = current_util[idx]; cache_filled_idx++; } if (is_root_node) { @@ -1779,7 +1783,8 @@ void SearchWorker::PickNodesToExtendTask( if (second_best_edge) { int estimated_visits_to_change_best = std::numeric_limits::max(); if (best_without_u < second_best) { - const auto n1 = current_nstarted[best_idx] + 1; + const float virtual_n = cur_iters[best_idx].GetVirtualN(); + const auto n1 = current_nstarted[best_idx] + 1 + virtual_n; estimated_visits_to_change_best = static_cast( std::max(1.0f, std::min(current_pol[best_idx] * puct_mult / (second_best - best_without_u) - @@ -1816,9 +1821,11 @@ void SearchWorker::PickNodesToExtendTask( child_node->IncrementNInFlight(new_visits); current_nstarted[best_idx] += new_visits; } - current_score[best_idx] = current_pol[best_idx] * puct_mult / - (1 + current_nstarted[best_idx]) + - current_util[best_idx]; + current_score[best_idx] = + current_pol[best_idx] * puct_mult / + (1.0f + current_nstarted[best_idx] + + cur_iters[best_idx].GetVirtualN()) + + current_util[best_idx]; } if ((decremented && (child_node->GetN() == 0 || child_node->IsTerminal()))) { @@ -2114,7 +2121,7 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) { if (next_score > q) { budget_to_spend = std::min(budget, int(edge.GetP() * puct_mult / (next_score - q) - - edge.GetNStarted()) + + (edge.GetNStarted() + edge.GetVirtualN())) + 1); } else { budget_to_spend = budget; @@ -2181,6 +2188,17 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process) { ApplyDirichletNoise(node, params_.GetNoiseEpsilon(), params_.GetNoiseAlpha()); } + const float node_prior_alpha = params_.GetNodePriorAlpha(); + const float prior_strength = + node_prior_alpha > 0.0f ? node_prior_alpha * node->GetNumEdges() : 0.0f; + const float parent_wl = node->GetWL(); + const float parent_d = node->GetD(); + for (auto edge_it : node->Edges()) { + float n0 = prior_strength * edge_it.GetP(); + float w0 = n0 * parent_wl; + float d0 = n0 * parent_d; + edge_it.edge()->SetVirtualStats(n0, w0, d0); + } node->SortEdges(); } diff --git a/src/search/dag_classic/node.cc b/src/search/dag_classic/node.cc index 1ad4762f0e..f718aba973 100644 --- a/src/search/dag_classic/node.cc +++ b/src/search/dag_classic/node.cc @@ -105,7 +105,8 @@ float Edge::GetP() const { std::string Edge::DebugString() const { std::ostringstream oss; oss << "Move: " << move_.ToString(true) << " p_: " << p_ - << " GetP: " << GetP(); + << " GetP: " << GetP() << " n0: " << virtual_n_ + << " w0: " << virtual_wl_ << " d0: " << virtual_d_; return oss.str(); } diff --git a/src/search/dag_classic/node.h b/src/search/dag_classic/node.h index b74a64a9d4..4f451dcb8a 100644 --- a/src/search/dag_classic/node.h +++ b/src/search/dag_classic/node.h @@ -189,6 +189,15 @@ class Edge { float GetP() const; void SetP(float val); + float GetVirtualN() const { return virtual_n_; } + float GetVirtualWL() const { return virtual_wl_; } + float GetVirtualD() const { return virtual_d_; } + void SetVirtualStats(float n0, float wl0, float d0) { + virtual_n_ = n0; + virtual_wl_ = wl0; + virtual_d_ = d0; + } + // Debug information about the edge. std::string DebugString() const; @@ -203,6 +212,9 @@ class Edge { // Probability that this move will be made, from the policy head of the neural // network; compressed to a 16 bit format (5 bits exp, 11 bits significand). uint16_t p_ = 0; + float virtual_n_ = 0.0f; + float virtual_wl_ = 0.0f; + float virtual_d_ = 0.0f; friend class Node; }; @@ -692,13 +704,56 @@ class EdgeAndNode { // Proxy functions for easier access to node/edge. float GetQ(float default_q, float draw_score) const { - return (node_ && node_->GetN() > 0) ? node_->GetQ(draw_score) : default_q; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f; + const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + const float wl_sum = node_->GetWL() * real_n + virtual_wl; + const float d_sum = node_->GetD() * real_n + virtual_d; + const float avg_wl = wl_sum / total_n; + const float avg_d = d_sum / total_n; + return avg_wl + draw_score * avg_d; + } + } + if (virtual_n > 0.0f) { + const float avg_wl = virtual_wl / virtual_n; + const float avg_d = virtual_d / virtual_n; + return avg_wl + draw_score * avg_d; + } + return default_q; } float GetWL(float default_wl) const { - return (node_ && node_->GetN() > 0) ? node_->GetWL() : default_wl; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + return (node_->GetWL() * real_n + virtual_wl) / total_n; + } + } + if (virtual_n > 0.0f) { + return virtual_wl / virtual_n; + } + return default_wl; } float GetD(float default_d) const { - return (node_ && node_->GetN() > 0) ? node_->GetD() : default_d; + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f; + if (node_ && node_->GetN() > 0) { + const float real_n = node_->GetN(); + const float total_n = real_n + virtual_n; + if (total_n > 0.0f) { + return (node_->GetD() * real_n + virtual_d) / total_n; + } + } + if (virtual_n > 0.0f) { + return virtual_d / virtual_n; + } + return default_d; } float GetM(float default_m) const { return (node_ && node_->GetN() > 0) ? node_->GetM() : default_m; @@ -727,9 +782,14 @@ class EdgeAndNode { // Returns U = numerator * p / N. // Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])). float GetU(float numerator) const { - return numerator * GetP() / (1 + GetNStarted()); + const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f; + return numerator * GetP() / (1.0f + GetNStarted() + virtual_n); } + float GetVirtualN() const { return edge_ ? edge_->GetVirtualN() : 0.0f; } + float GetVirtualWL() const { return edge_ ? edge_->GetVirtualWL() : 0.0f; } + float GetVirtualD() const { return edge_ ? edge_->GetVirtualD() : 0.0f; } + std::string DebugString() const; protected: diff --git a/src/search/dag_classic/search.cc b/src/search/dag_classic/search.cc index e65977c38c..e6f672a90d 100644 --- a/src/search/dag_classic/search.cc +++ b/src/search/dag_classic/search.cc @@ -1792,8 +1792,11 @@ void SearchWorker::PickNodesToExtendTask( float visited_pol = 0.0f; for (Node* child : node->VisitedNodes()) { int index = child->Index(); - visited_pol += child->GetP(); - float q = child->GetQ(draw_score); + Edge* parent_edges = node->GetLowNode()->GetEdges(); + EdgeAndNode edge_wrapper(parent_edges ? parent_edges + index : nullptr, + child); + visited_pol += edge_wrapper.GetP(); + float q = edge_wrapper.GetQ(child->GetQ(draw_score), draw_score); current_util[index] = q + m_evaluator.GetMUtility(child, q); } const float fpu = @@ -1828,10 +1831,13 @@ void SearchWorker::PickNodesToExtendTask( current_nstarted[idx] = cur_iters[idx].GetNStarted(); } int nstarted = current_nstarted[idx]; - const float util = current_util[idx]; + float util = current_util[idx]; if (idx > cache_filled_idx) { - current_score[idx] = - cur_iters[idx].GetP() * puct_mult / (1 + nstarted) + util; + const float virtual_n = cur_iters[idx].GetVirtualN(); + current_score[idx] = cur_iters[idx].GetP() * puct_mult / + (1.0f + nstarted + virtual_n) + + current_util[idx]; + util = current_util[idx]; cache_filled_idx++; } if (is_root_node) { @@ -1877,7 +1883,8 @@ void SearchWorker::PickNodesToExtendTask( if (second_best_edge) { int estimated_visits_to_change_best = std::numeric_limits::max(); if (best_without_u < second_best) { - const auto n1 = current_nstarted[best_idx] + 1; + const float virtual_n = cur_iters[best_idx].GetVirtualN(); + const auto n1 = current_nstarted[best_idx] + 1 + virtual_n; estimated_visits_to_change_best = static_cast( std::max(1.0f, std::min(cur_iters[best_idx].GetP() * puct_mult / (second_best - best_without_u) - @@ -1917,9 +1924,11 @@ void SearchWorker::PickNodesToExtendTask( child_node->IncrementNInFlight(new_visits); current_nstarted[best_idx] += new_visits; } - current_score[best_idx] = cur_iters[best_idx].GetP() * puct_mult / - (1 + current_nstarted[best_idx]) + - current_util[best_idx]; + current_score[best_idx] = + cur_iters[best_idx].GetP() * puct_mult / + (1.0f + current_nstarted[best_idx] + + cur_iters[best_idx].GetVirtualN()) + + current_util[best_idx]; } if (best_idx > vtp_last_filled.back() && (*visits_to_perform.back())[best_idx] > 0) { @@ -2154,6 +2163,19 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process) { params_.GetNoiseEpsilon(), params_.GetNoiseAlpha()); node_to_process->tt_low_node->SortEdges(); } + const float node_prior_alpha = params_.GetNodePriorAlpha(); + const float prior_strength = + node_prior_alpha > 0.0f + ? node_prior_alpha * node_to_process->tt_low_node->GetNumEdges() + : 0.0f; + const float parent_wl = node->GetWL(); + const float parent_d = node->GetD(); + Edge* edges = node_to_process->tt_low_node->GetEdges(); + for (int i = 0, num = node_to_process->tt_low_node->GetNumEdges(); i < num; + ++i) { + float n0 = prior_strength * edges[i].GetP(); + edges[i].SetVirtualStats(n0, n0 * parent_wl, n0 * parent_d); + } } // 6. Propagate the new nodes' information to all their parents in the tree.