diff --git a/src/games/game.h b/src/games/game.h index 4b8cc9143..66cf5ae00 100644 --- a/src/games/game.h +++ b/src/games/game.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include "number.h" @@ -627,101 +628,180 @@ class GameRep : public std::enable_shared_from_this { using Players = ElementCollection; using Outcomes = ElementCollection; - class Nodes { - Game m_owner{nullptr}; - TraversalOrder m_order{TraversalOrder::Preorder}; + enum class DFSCallbackResult { + Continue, // continue with normal traversal + Prune, // skip the subtree under this node + Stop // end traversal early now + }; - public: - class iterator { - friend class Nodes; + /// @brief Perform a depth-first traversal of the game tree. + /// + /// WalkDFS performs a depth-first traversal of the game tree starting at + /// the node @p p_root. Traversal order (preorder or postorder) controls when + /// the OnVisit() hook is called relative to the other callbacks. + /// + /// @tparam Callback A callback type implementing the required interface (see below)). + /// @param p_game The game object being traversed + /// @param p_root The root node of the subtree to traverse. + /// @param p_order Controls when OnVisit() is invoked: + /// - TraversalOrder::Preorder: OnVisit() is called when a node is first entered. + /// - TraversalOrder::Postorder: OnVisit() is called after all children have been processed. + /// @param p_callback The callback object + /// + /// The callback object must implement the following four member functions: + /// + /// @code + /// DFSCallbackResult OnEnter(GameNode node, int depth); + /// DFSCallbackResult OnAction(GameNode parent, GameNode child, int depth); + /// DFSCallbackResult OnExit(GameNode node, int depth); + /// void OnVisit(GameNode node, int depth); + /// @endcode + /// + /// OnEnter(node, depth) + /// -------------------- + /// Invoked exactly once for each node, when the node is first reached during DFS. + /// @p depth is the depth of the node relative to @p p_root (where the root is defined + /// to have depth 0). + /// + /// OnAction(parent, child, depth) + /// ------------------------------ + /// Invoked immediately before descending from @p parent to one of its children. + /// @p depth is the depth of the parent node. + /// + /// OnExit(node, depth) + /// ------------------- + /// Invoked exactly once for each node, after all of its children have been + /// processed (or skipped). @p is the depth of the node. + /// + /// OnVisit(node, depth) + /// -------------------- + /// Invoked exactly once per node, at a time determined by @p p_order: + /// - Preorder: called during OnEnter(), before any children are visited. + /// - Postorder: called after OnExit(), once all children are complete. + /// + template + static void WalkDFS(const Game &p_game, const GameNode &p_root, TraversalOrder p_order, + Callback &p_callback) + { + using ChildIterator = ElementCollection::iterator; + + struct Frame { + GameNode m_node; + ChildIterator m_current; + ChildIterator m_end; + int m_depth{}; + bool m_entered{}, m_pruned{}; + }; - using ChildIterator = ElementCollection::iterator; + if (!p_root) { + return; + } + if (p_root->GetGame() != p_game) { + throw MismatchException(); + } - struct Frame { - GameNode m_node; - ChildIterator m_current, m_end; - }; + std::stack stack; + stack.push(Frame{p_root, {}, {}, 0, false, false}); - Game m_owner{nullptr}; - TraversalOrder m_order{TraversalOrder::Preorder}; - std::stack m_stack{}; - GameNode m_current{nullptr}; + while (!stack.empty()) { + Frame &f = stack.top(); - iterator(const Game &p_game, const GameNode &p_start, const TraversalOrder p_order) - : m_owner(p_game), m_order(p_order) - { - if (!p_start) { + if (!f.m_entered) { + f.m_entered = true; + const DFSCallbackResult result = p_callback.OnEnter(f.m_node, f.m_depth); + if (result == DFSCallbackResult::Stop) { return; } - if (p_start->GetGame() != p_game) { - throw MismatchException(); + if (result == DFSCallbackResult::Prune) { + f.m_pruned = true; } - m_stack.push(make_frame(p_start)); - if (m_order == TraversalOrder::Preorder) { - m_current = p_start; + if (p_order == TraversalOrder::Preorder) { + p_callback.OnVisit(f.m_node, f.m_depth); } - else { - descend_postorder(); - m_current = m_stack.empty() ? nullptr : m_stack.top().m_node; - } - if (!m_current) { - m_owner = nullptr; + if (!f.m_pruned && !f.m_node->IsTerminal()) { + auto children = f.m_node->GetChildren(); + f.m_current = children.begin(); + f.m_end = children.end(); } } - static Frame make_frame(const GameNode &p_node) - { - if (p_node->IsTerminal()) { - return Frame{p_node, {}, {}}; + if (!f.m_pruned && !f.m_node->IsTerminal() && f.m_current != f.m_end) { + GameNode const child = *f.m_current; + ++f.m_current; + const DFSCallbackResult result = p_callback.OnAction(f.m_node, child, f.m_depth); + + if (result == DFSCallbackResult::Stop) { + return; + } + if (result == DFSCallbackResult::Prune) { + continue; } - const auto children = p_node->GetChildren(); - return Frame{p_node, children.begin(), children.end()}; + stack.push(Frame{child, {}, {}, f.m_depth + 1, false, false}); + continue; } - GameNode advance_preorder() - { - if (auto &[node, current, m_end] = m_stack.top(); !node->IsTerminal()) { - const auto children = node->GetChildren(); - current = children.begin(); - m_end = children.end(); + const DFSCallbackResult result = p_callback.OnExit(f.m_node, f.m_depth); + if (result == DFSCallbackResult::Stop) { + return; + } + if (p_order == TraversalOrder::Postorder) { + p_callback.OnVisit(f.m_node, f.m_depth); + } + stack.pop(); + } + } + + class Nodes { + Game m_owner{nullptr}; + TraversalOrder m_order{TraversalOrder::Preorder}; + + public: + class iterator { + friend class Nodes; + + struct NodeHandler { + std::queue m_queue; + + static DFSCallbackResult OnEnter(const GameNode &, int) + { + return DFSCallbackResult::Continue; } - while (!m_stack.empty()) { - if (auto &f = m_stack.top(); f.m_current != f.m_end) { - GameNode const next = *f.m_current; - ++f.m_current; - m_stack.push(make_frame(next)); - return next; - } - m_stack.pop(); + static DFSCallbackResult OnAction(const GameNode &, const GameNode &, int) + { + return DFSCallbackResult::Continue; } - return nullptr; - } + static DFSCallbackResult OnExit(const GameNode &, int) + { + return DFSCallbackResult::Continue; + } + void OnVisit(const GameNode &p_node, int) { m_queue.push(p_node); } + }; - void descend_postorder() + Game m_owner{nullptr}; + std::shared_ptr m_handler; + GameNode m_current{nullptr}; + + iterator(const Game &p_game, TraversalOrder p_order) + : m_owner(p_game), m_handler(std::make_shared()) { - while (!m_stack.empty()) { - auto &f = m_stack.top(); - if (f.m_current == f.m_end) { - return; - } - const auto child = *f.m_current; - ++f.m_current; - m_stack.push(make_frame(child)); - if (!child->IsTerminal()) { - continue; - } + if (!p_game) { + m_owner = nullptr; return; } + WalkDFS(p_game, p_game->GetRoot(), p_order, *m_handler); + advance(); } - GameNode advance_postorder() + void advance() { - m_stack.pop(); - if (m_stack.empty()) { - return nullptr; + if (!m_handler || m_handler->m_queue.empty()) { + m_current = nullptr; + m_owner = nullptr; + } + else { + m_current = m_handler->m_queue.front(); + m_handler->m_queue.pop(); } - descend_postorder(); - return m_stack.empty() ? nullptr : m_stack.top().m_node; } public: @@ -732,8 +812,6 @@ class GameRep : public std::enable_shared_from_this { using pointer = GameNode; iterator() = default; - iterator(const iterator &) = default; - iterator &operator=(const iterator &) = default; value_type operator*() const { @@ -742,21 +820,11 @@ class GameRep : public std::enable_shared_from_this { } return m_current; } - iterator &operator++() { - if (!m_current) { - return *this; - } - const auto next = - (m_order == TraversalOrder::Preorder) ? advance_preorder() : advance_postorder(); - m_current = next; - if (!m_current) { - m_owner = nullptr; - } + advance(); return *this; } - bool operator==(const iterator &p_other) const { return m_owner == p_other.m_owner && m_current == p_other.m_current; @@ -770,10 +838,7 @@ class GameRep : public std::enable_shared_from_this { { } - iterator begin() const - { - return (m_owner) ? iterator{m_owner, m_owner->GetRoot(), m_order} : iterator{}; - } + iterator begin() const { return m_owner ? iterator{m_owner, m_order} : iterator{}; } static iterator end() { return iterator{}; } }; diff --git a/src/games/gametree.cc b/src/games/gametree.cc index ba5302d5c..3df2700c0 100644 --- a/src/games/gametree.cc +++ b/src/games/gametree.cc @@ -739,83 +739,112 @@ Game NewTree() { return std::make_shared(); } // GameTreeRep: General data access //------------------------------------------------------------------------ -namespace { - -class NotZeroSumException final : public std::runtime_error { -public: - NotZeroSumException() : std::runtime_error("Game is not constant sum") {} - ~NotZeroSumException() noexcept override = default; -}; - -Rational SubtreeSum(GameNode p_node) +bool GameTreeRep::IsConstSum() const { - Rational sum(0); + struct ConstSumCallback { + const GameTreeRep *m_game; + std::map m_subtreeSums; + bool m_isConstSum{true}; - if (!p_node->IsTerminal()) { - const auto children = p_node->GetChildren(); - sum = SubtreeSum(children.front()); - if (std::any_of(std::next(children.begin()), children.end(), - [sum](const GameNode &n) { return SubtreeSum(n) != sum; })) { - throw NotZeroSumException(); + static DFSCallbackResult OnEnter(GameNode, int) { return DFSCallbackResult::Continue; } + static DFSCallbackResult OnAction(GameNode, GameNode, int) + { + return DFSCallbackResult::Continue; } - } + DFSCallbackResult OnExit(const GameNode &p_node, int) + { + Rational sum(0); + if (!p_node->IsTerminal()) { + const auto children = p_node->GetChildren(); + + if (std::adjacent_find(children.begin(), children.end(), + [&](const GameNode &a, const GameNode &b) { + return m_subtreeSums[a] != m_subtreeSums[b]; + }) != children.end()) { + m_isConstSum = false; + return DFSCallbackResult::Stop; + } + sum = m_subtreeSums[*children.begin()]; + for (const auto &child : children) { + m_subtreeSums.erase(child); + } + } - if (p_node->GetOutcome()) { - for (const auto &player : p_node->GetGame()->GetPlayers()) { - sum += p_node->GetOutcome()->GetPayoff(player); + if (const auto outcome = p_node->GetOutcome()) { + sum += sum_function(m_game->m_players, [&](const auto &p_player) { + return outcome->GetPayoff(p_player); + }); + } + m_subtreeSums[p_node] = sum; + return DFSCallbackResult::Continue; } - } - return sum; + static void OnVisit(GameNode, int) {} + }; + + ConstSumCallback callback{this}; + WalkDFS(Game(const_cast(this)->shared_from_this()), m_root, + TraversalOrder::Postorder, callback); + return callback.m_isConstSum; } -Rational -AggregateSubtreePayoff(const GamePlayer &p_player, const GameNode &p_node, - std::function p_aggregator) +template +Rational GameTreeRep::AggregateSubtreePayoff(const GamePlayer &p_player, + Aggregator p_aggregator) const { - if (p_node->IsTerminal()) { - if (p_node->GetOutcome()) { - return p_node->GetOutcome()->GetPayoff(p_player); + struct AggregatePayoffCallback { + const GamePlayer &m_player; + Aggregator m_aggregator; + + std::map m_subtreeValues; + Rational m_result{0}; + + static DFSCallbackResult OnEnter(GameNode, int) { return DFSCallbackResult::Continue; } + static DFSCallbackResult OnAction(GameNode, GameNode, int) + { + return DFSCallbackResult::Continue; + } + DFSCallbackResult OnExit(const GameNode &p_node, int) + { + Rational value(0); + if (!p_node->IsTerminal()) { + const auto children = p_node->GetChildren(); + value = m_aggregator(children, [&](const GameNode &c) { return m_subtreeValues[c]; }); + for (const auto &child : children) { + m_subtreeValues.erase(child); + } + } + if (const auto outcome = p_node->GetOutcome()) { + value += outcome->GetPayoff(m_player); + } + m_subtreeValues[p_node] = value; + m_result = value; // We write the root node value last, so will be correct on termination + return DFSCallbackResult::Continue; } - return Rational(0); - } - const auto &children = p_node->GetChildren(); - auto subtree = - std::accumulate(std::next(children.begin()), children.end(), - AggregateSubtreePayoff(p_player, children.front(), p_aggregator), - [&p_aggregator, &p_player](const Rational &r, const GameNode &c) { - return p_aggregator(r, AggregateSubtreePayoff(p_player, c, p_aggregator)); - }); - if (p_node->GetOutcome()) { - return subtree + p_node->GetOutcome()->GetPayoff(p_player); - } - return subtree; -} -} // end anonymous namespace + static void OnVisit(GameNode, int) {} + }; -bool GameTreeRep::IsConstSum() const -{ - try { - SubtreeSum(m_root); - return true; - } - catch (NotZeroSumException &) { - return false; - } + AggregatePayoffCallback callback{p_player, std::move(p_aggregator)}; + + WalkDFS(Game(const_cast(this)->shared_from_this()), m_root, + TraversalOrder::Postorder, callback); + + return callback.m_result; } Rational GameTreeRep::GetPlayerMinPayoff(const GamePlayer &p_player) const { - return AggregateSubtreePayoff( - p_player, m_root, [](const Rational &a, const Rational &b) { return std::min(a, b); }); + return AggregateSubtreePayoff(p_player, [](const auto &range, auto value_fn) { + return minimize_function(range, value_fn); + }); } Rational GameTreeRep::GetPlayerMaxPayoff(const GamePlayer &p_player) const { - return AggregateSubtreePayoff( - p_player, m_root, [](const Rational &a, const Rational &b) { return std::max(a, b); }); + return AggregateSubtreePayoff(p_player, [](const auto &range, auto value_fn) { + return maximize_function(range, value_fn); + }); } - bool GameTreeRep::IsPerfectRecall() const { if (!m_ownPriorActionInfo && !m_root->IsTerminal()) { diff --git a/src/games/gametree.h b/src/games/gametree.h index ead211837..03247b661 100644 --- a/src/games/gametree.h +++ b/src/games/gametree.h @@ -51,6 +51,8 @@ class GameTreeRep final : public GameExplicitRep { /// @name Private auxiliary functions //@{ static void SortInfosets(GamePlayerRep *); + template + Rational AggregateSubtreePayoff(const GamePlayer &p_player, Aggregator p_aggregator) const; static void RenumberInfosets(GamePlayerRep *); /// Normalize the probability distribution of actions at a chance node Game NormalizeChanceProbs(GameInfosetRep *);