Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/search/classic/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ float Edge::GetP() const {
// Debug information about the edge: the move, the raw compressed policy
// value, the decoded policy, and the virtual (prior) statistics that were
// applied at expansion time.
std::string Edge::DebugString() const {
  std::ostringstream oss;
  oss << "Move: " << move_.ToString(true) << " p_: " << p_
      << " GetP: " << GetP() << " n0: " << virtual_n_
      << " w0: " << virtual_wl_ << " d0: " << virtual_d_;
  return oss.str();
}

Expand Down
70 changes: 66 additions & 4 deletions src/search/classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ class Edge {
float GetP() const;
void SetP(float val);

// Virtual (prior) statistics for this edge: a pseudo-visit count and the
// corresponding WL/D sums, blended into Q and U before real visits
// accumulate. Set once during node expansion.
float GetVirtualN() const { return virtual_n_; }
float GetVirtualWL() const { return virtual_wl_; }
float GetVirtualD() const { return virtual_d_; }
// Stores all three virtual statistics at once: pseudo-visit count n0,
// WL sum wl0, and draw-probability sum d0.
void SetVirtualStats(float n0, float wl0, float d0) {
virtual_n_ = n0;
virtual_wl_ = wl0;
virtual_d_ = d0;
}

// Debug information about the edge.
std::string DebugString() const;

Expand All @@ -108,6 +117,11 @@ class Edge {
// Probability that this move will be made, from the policy head of the neural
// network; compressed to a 16 bit format (5 bits exp, 11 bits significand).
uint16_t p_ = 0;

// Virtual statistics applied once during expansion.
float virtual_n_ = 0.0f;
float virtual_wl_ = 0.0f;
float virtual_d_ = 0.0f;
friend class Node;
};

Expand Down Expand Up @@ -373,13 +387,56 @@ class EdgeAndNode {

// Proxy functions for easier access to node/edge.
// Returns Q blended with the edge's virtual (prior) statistics:
//   Q = (WL_real * N_real + WL_virtual) / (N_real + N_virtual)
//       + draw_score * (analogous average for D).
// Falls back to the pure virtual average when the node has no real visits,
// and to default_q when no statistics exist at all.
float GetQ(float default_q, float draw_score) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f;
  const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      const float wl_sum = node_->GetWL() * real_n + virtual_wl;
      const float d_sum = node_->GetD() * real_n + virtual_d;
      return wl_sum / total_n + draw_score * (d_sum / total_n);
    }
  }
  if (virtual_n > 0.0f) {
    return virtual_wl / virtual_n + draw_score * (virtual_d / virtual_n);
  }
  return default_q;
}
// Returns WL blended with the edge's virtual statistics (visit-weighted
// average of real and virtual WL); falls back to the pure virtual average,
// then to default_wl when no statistics exist.
float GetWL(float default_wl) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      return (node_->GetWL() * real_n + virtual_wl) / total_n;
    }
  }
  if (virtual_n > 0.0f) return virtual_wl / virtual_n;
  return default_wl;
}
// Returns D blended with the edge's virtual statistics (visit-weighted
// average of real and virtual D); falls back to the pure virtual average,
// then to default_d when no statistics exist.
float GetD(float default_d) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      return (node_->GetD() * real_n + virtual_d) / total_n;
    }
  }
  if (virtual_n > 0.0f) return virtual_d / virtual_n;
  return default_d;
}
float GetM(float default_m) const {
return (node_ && node_->GetN() > 0) ? node_->GetM() : default_m;
Expand All @@ -406,9 +463,14 @@ class EdgeAndNode {
// Returns U = numerator * p / N.
// Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])).
float GetU(float numerator) const {
  // Virtual visits inflate the effective visit count, damping exploration
  // of edges that already carry prior statistics.
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  return numerator * GetP() / (1.0f + GetNStarted() + virtual_n);
}

// Virtual statistics of the underlying edge; 0 when there is no edge.
float GetVirtualN() const { return edge_ ? edge_->GetVirtualN() : 0.0f; }
float GetVirtualWL() const { return edge_ ? edge_->GetVirtualWL() : 0.0f; }
float GetVirtualD() const { return edge_ ? edge_->GetVirtualD() : 0.0f; }

std::string DebugString() const;

protected:
Expand Down
5 changes: 5 additions & 0 deletions src/search/classic/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ const OptionId BaseSearchParams::kCpuctFactorAtRootId{
.uci_option = "CPuctFactorAtRoot",
.help_text = "Multiplier for the cpuct growth formula at root.",
.visibility = OptionId::kProOnly}};
// UCI option controlling the strength of virtual node priors applied at
// expansion (defaults to 0, i.e. disabled).
const OptionId BaseSearchParams::kNodePriorAlphaId{
"node-prior-alpha", "NodePriorAlpha",
"Strength multiplier for virtual node priors (K = alpha * legal moves)."};
// Remove this option after 0.25 has been made mandatory in training and the
// training server stops sending it.
const OptionId BaseSearchParams::kRootHasOwnCpuctParamsId{
Expand Down Expand Up @@ -550,6 +553,7 @@ void BaseSearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kCpuctBaseAtRootId, 1.0f, 1000000000.0f) = 38739.0f;
options->Add<FloatOption>(kCpuctFactorId, 0.0f, 1000.0f) = 3.894f;
options->Add<FloatOption>(kCpuctFactorAtRootId, 0.0f, 1000.0f) = 3.894f;
options->Add<FloatOption>(kNodePriorAlphaId, 0.0f, 1000.0f) = 0.0f;
options->Add<BoolOption>(kRootHasOwnCpuctParamsId) = false;
options->Add<BoolOption>(kTwoFoldDrawsId) = true;
options->Add<FloatOption>(kTemperatureId, 0.0f, 100.0f) = 0.0f;
Expand Down Expand Up @@ -653,6 +657,7 @@ BaseSearchParams::BaseSearchParams(const OptionsDict& options)
kCpuctFactorAtRoot(options.Get<float>(
options.Get<bool>(kRootHasOwnCpuctParamsId) ? kCpuctFactorAtRootId
: kCpuctFactorId)),
kNodePriorAlpha(options.Get<float>(kNodePriorAlphaId)),
kTwoFoldDraws(options.Get<bool>(kTwoFoldDrawsId)),
kNoiseEpsilon(options.Get<float>(kNoiseEpsilonId)),
kNoiseAlpha(options.Get<float>(kNoiseAlphaId)),
Expand Down
3 changes: 3 additions & 0 deletions src/search/classic/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class BaseSearchParams {
// Returns the cpuct growth factor; the root position may use its own value.
float GetCpuctFactor(bool at_root) const {
  if (at_root) return kCpuctFactorAtRoot;
  return kCpuctFactor;
}
// Strength multiplier for virtual node priors; 0 disables them.
float GetNodePriorAlpha() const { return kNodePriorAlpha; }
bool GetTwoFoldDraws() const { return kTwoFoldDraws; }
float GetTemperature() const { return options_.Get<float>(kTemperatureId); }
float GetTemperatureVisitOffset() const {
Expand Down Expand Up @@ -171,6 +172,7 @@ class BaseSearchParams {
static const OptionId kCpuctBaseAtRootId;
static const OptionId kCpuctFactorId;
static const OptionId kCpuctFactorAtRootId;
static const OptionId kNodePriorAlphaId;
static const OptionId kRootHasOwnCpuctParamsId;
static const OptionId kTwoFoldDrawsId;
static const OptionId kTemperatureId;
Expand Down Expand Up @@ -248,6 +250,7 @@ class BaseSearchParams {
const float kCpuctBaseAtRoot;
const float kCpuctFactor;
const float kCpuctFactorAtRoot;
const float kNodePriorAlpha;
const bool kTwoFoldDraws;
const float kNoiseEpsilon;
const float kNoiseAlpha;
Expand Down
36 changes: 27 additions & 9 deletions src/search/classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,8 @@ void SearchWorker::PickNodesToExtendTask(
for (Node* child : node->VisitedNodes()) {
int index = child->Index();
visited_pol += current_pol[index];
float q = child->GetQ(draw_score);
EdgeAndNode edge_wrapper(node->GetEdgeToNode(child), child);
float q = edge_wrapper.GetQ(child->GetQ(draw_score), draw_score);
current_util[index] = q + m_evaluator.GetMUtility(child, q);
}
const float fpu =
Expand Down Expand Up @@ -1730,10 +1731,13 @@ void SearchWorker::PickNodesToExtendTask(
current_nstarted[idx] = cur_iters[idx].GetNStarted();
}
int nstarted = current_nstarted[idx];
const float util = current_util[idx];
float util = current_util[idx];
if (idx > cache_filled_idx) {
current_score[idx] =
current_pol[idx] * puct_mult / (1 + nstarted) + util;
const float virtual_n = cur_iters[idx].GetVirtualN();
current_score[idx] = current_pol[idx] * puct_mult /
(1.0f + nstarted + virtual_n) +
current_util[idx];
util = current_util[idx];
cache_filled_idx++;
}
if (is_root_node) {
Expand Down Expand Up @@ -1779,7 +1783,8 @@ void SearchWorker::PickNodesToExtendTask(
if (second_best_edge) {
int estimated_visits_to_change_best = std::numeric_limits<int>::max();
if (best_without_u < second_best) {
const auto n1 = current_nstarted[best_idx] + 1;
const float virtual_n = cur_iters[best_idx].GetVirtualN();
const auto n1 = current_nstarted[best_idx] + 1 + virtual_n;
estimated_visits_to_change_best = static_cast<int>(
std::max(1.0f, std::min(current_pol[best_idx] * puct_mult /
(second_best - best_without_u) -
Expand Down Expand Up @@ -1816,9 +1821,11 @@ void SearchWorker::PickNodesToExtendTask(
child_node->IncrementNInFlight(new_visits);
current_nstarted[best_idx] += new_visits;
}
current_score[best_idx] = current_pol[best_idx] * puct_mult /
(1 + current_nstarted[best_idx]) +
current_util[best_idx];
current_score[best_idx] =
current_pol[best_idx] * puct_mult /
(1.0f + current_nstarted[best_idx] +
cur_iters[best_idx].GetVirtualN()) +
current_util[best_idx];
}
if ((decremented &&
(child_node->GetN() == 0 || child_node->IsTerminal()))) {
Expand Down Expand Up @@ -2114,7 +2121,7 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget, bool is_odd_depth) {
if (next_score > q) {
budget_to_spend =
std::min(budget, int(edge.GetP() * puct_mult / (next_score - q) -
edge.GetNStarted()) +
(edge.GetNStarted() + edge.GetVirtualN())) +
1);
} else {
budget_to_spend = budget;
Expand Down Expand Up @@ -2181,6 +2188,17 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process) {
ApplyDirichletNoise(node, params_.GetNoiseEpsilon(),
params_.GetNoiseAlpha());
}
const float node_prior_alpha = params_.GetNodePriorAlpha();
const float prior_strength =
node_prior_alpha > 0.0f ? node_prior_alpha * node->GetNumEdges() : 0.0f;
const float parent_wl = node->GetWL();
const float parent_d = node->GetD();
for (auto edge_it : node->Edges()) {
float n0 = prior_strength * edge_it.GetP();
float w0 = n0 * parent_wl;
float d0 = n0 * parent_d;
edge_it.edge()->SetVirtualStats(n0, w0, d0);
}
node->SortEdges();
}

Expand Down
3 changes: 2 additions & 1 deletion src/search/dag_classic/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ float Edge::GetP() const {
// Debug information about the edge: the move, the raw compressed policy
// value, the decoded policy, and the virtual (prior) statistics that were
// applied at expansion time.
std::string Edge::DebugString() const {
  std::ostringstream oss;
  oss << "Move: " << move_.ToString(true) << " p_: " << p_
      << " GetP: " << GetP() << " n0: " << virtual_n_
      << " w0: " << virtual_wl_ << " d0: " << virtual_d_;
  return oss.str();
}

Expand Down
68 changes: 64 additions & 4 deletions src/search/dag_classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ class Edge {
float GetP() const;
void SetP(float val);

// Virtual (prior) statistics for this edge: a pseudo-visit count and the
// corresponding WL/D sums, blended into Q and U before real visits
// accumulate. Set once during node expansion.
float GetVirtualN() const { return virtual_n_; }
float GetVirtualWL() const { return virtual_wl_; }
float GetVirtualD() const { return virtual_d_; }
// Stores all three virtual statistics at once: pseudo-visit count n0,
// WL sum wl0, and draw-probability sum d0.
void SetVirtualStats(float n0, float wl0, float d0) {
virtual_n_ = n0;
virtual_wl_ = wl0;
virtual_d_ = d0;
}

// Debug information about the edge.
std::string DebugString() const;

Expand All @@ -203,6 +212,9 @@ class Edge {
// Probability that this move will be made, from the policy head of the neural
// network; compressed to a 16 bit format (5 bits exp, 11 bits significand).
uint16_t p_ = 0;
float virtual_n_ = 0.0f;
float virtual_wl_ = 0.0f;
float virtual_d_ = 0.0f;
friend class Node;
};

Expand Down Expand Up @@ -692,13 +704,56 @@ class EdgeAndNode {

// Proxy functions for easier access to node/edge.
// Returns Q blended with the edge's virtual (prior) statistics:
//   Q = (WL_real * N_real + WL_virtual) / (N_real + N_virtual)
//       + draw_score * (analogous average for D).
// Falls back to the pure virtual average when the node has no real visits,
// and to default_q when no statistics exist at all.
float GetQ(float default_q, float draw_score) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f;
  const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      const float wl_sum = node_->GetWL() * real_n + virtual_wl;
      const float d_sum = node_->GetD() * real_n + virtual_d;
      return wl_sum / total_n + draw_score * (d_sum / total_n);
    }
  }
  if (virtual_n > 0.0f) {
    return virtual_wl / virtual_n + draw_score * (virtual_d / virtual_n);
  }
  return default_q;
}
// Returns WL blended with the edge's virtual statistics (visit-weighted
// average of real and virtual WL); falls back to the pure virtual average,
// then to default_wl when no statistics exist.
float GetWL(float default_wl) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_wl = edge_ ? edge_->GetVirtualWL() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      return (node_->GetWL() * real_n + virtual_wl) / total_n;
    }
  }
  if (virtual_n > 0.0f) return virtual_wl / virtual_n;
  return default_wl;
}
// Returns D blended with the edge's virtual statistics (visit-weighted
// average of real and virtual D); falls back to the pure virtual average,
// then to default_d when no statistics exist.
float GetD(float default_d) const {
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  const float virtual_d = edge_ ? edge_->GetVirtualD() : 0.0f;
  if (node_ && node_->GetN() > 0) {
    const float real_n = node_->GetN();
    const float total_n = real_n + virtual_n;
    if (total_n > 0.0f) {
      return (node_->GetD() * real_n + virtual_d) / total_n;
    }
  }
  if (virtual_n > 0.0f) return virtual_d / virtual_n;
  return default_d;
}
float GetM(float default_m) const {
return (node_ && node_->GetN() > 0) ? node_->GetM() : default_m;
Expand Down Expand Up @@ -727,9 +782,14 @@ class EdgeAndNode {
// Returns U = numerator * p / N.
// Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])).
float GetU(float numerator) const {
  // Virtual visits inflate the effective visit count, damping exploration
  // of edges that already carry prior statistics.
  const float virtual_n = edge_ ? edge_->GetVirtualN() : 0.0f;
  return numerator * GetP() / (1.0f + GetNStarted() + virtual_n);
}

// Virtual statistics of the underlying edge; 0 when there is no edge.
float GetVirtualN() const { return edge_ ? edge_->GetVirtualN() : 0.0f; }
float GetVirtualWL() const { return edge_ ? edge_->GetVirtualWL() : 0.0f; }
float GetVirtualD() const { return edge_ ? edge_->GetVirtualD() : 0.0f; }

std::string DebugString() const;

protected:
Expand Down
Loading