Skip to content

Commit 941eea7

Browse files
authored
Expose visitor pattern for traversing trees (#733)
This refactors the implementation of traversing the nodes in a tree to expose a full visitor-pattern interface. The existing preorder and postorder traversal modes are now implemented as special cases of the visitor pattern. This re-implements two algorithms to eliminate recursion by using the visitor pattern: * Determining if a game is constant-sum * Determining the minimum/maximum payoff to a player in any play of the game.
1 parent a63fdab commit 941eea7

3 files changed

Lines changed: 240 additions & 144 deletions

File tree

src/games/game.h

Lines changed: 152 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <list>
2727
#include <set>
2828
#include <stack>
29+
#include <queue>
2930
#include <memory>
3031

3132
#include "number.h"
@@ -627,101 +628,180 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
627628
using Players = ElementCollection<Game, GamePlayerRep>;
628629
using Outcomes = ElementCollection<Game, GameOutcomeRep>;
629630

630-
class Nodes {
631-
Game m_owner{nullptr};
632-
TraversalOrder m_order{TraversalOrder::Preorder};
631+
enum class DFSCallbackResult {
632+
Continue, // continue with normal traversal
633+
Prune, // skip the subtree under this node
634+
Stop // end traversal early now
635+
};
633636

634-
public:
635-
class iterator {
636-
friend class Nodes;
637+
/// @brief Perform a depth-first traversal of the game tree.
638+
///
639+
/// WalkDFS performs a depth-first traversal of the game tree starting at
640+
/// the node @p p_root. Traversal order (preorder or postorder) controls when
641+
/// the OnVisit() hook is called relative to the other callbacks.
642+
///
643+
/// @tparam Callback A callback type implementing the required interface (see below)).
644+
/// @param p_game The game object being traversed
645+
/// @param p_root The root node of the subtree to traverse.
646+
/// @param p_order Controls when OnVisit() is invoked:
647+
/// - TraversalOrder::Preorder: OnVisit() is called when a node is first entered.
648+
/// - TraversalOrder::Postorder: OnVisit() is called after all children have been processed.
649+
/// @param p_callback The callback object
650+
///
651+
/// The callback object must implement the following four member functions:
652+
///
653+
/// @code
654+
/// DFSCallbackResult OnEnter(GameNode node, int depth);
655+
/// DFSCallbackResult OnAction(GameNode parent, GameNode child, int depth);
656+
/// DFSCallbackResult OnExit(GameNode node, int depth);
657+
/// void OnVisit(GameNode node, int depth);
658+
/// @endcode
659+
///
660+
/// OnEnter(node, depth)
661+
/// --------------------
662+
/// Invoked exactly once for each node, when the node is first reached during DFS.
663+
/// @p depth is the depth of the node relative to @p p_root (where the root is defined
664+
/// to have depth 0).
665+
///
666+
/// OnAction(parent, child, depth)
667+
/// ------------------------------
668+
/// Invoked immediately before descending from @p parent to one of its children.
669+
/// @p depth is the depth of the parent node.
670+
///
671+
/// OnExit(node, depth)
672+
/// -------------------
673+
/// Invoked exactly once for each node, after all of its children have been
674+
/// processed (or skipped). @p is the depth of the node.
675+
///
676+
/// OnVisit(node, depth)
677+
/// --------------------
678+
/// Invoked exactly once per node, at a time determined by @p p_order:
679+
/// - Preorder: called during OnEnter(), before any children are visited.
680+
/// - Postorder: called after OnExit(), once all children are complete.
681+
///
682+
template <class Callback>
683+
static void WalkDFS(const Game &p_game, const GameNode &p_root, TraversalOrder p_order,
684+
Callback &p_callback)
685+
{
686+
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
687+
688+
struct Frame {
689+
GameNode m_node;
690+
ChildIterator m_current;
691+
ChildIterator m_end;
692+
int m_depth{};
693+
bool m_entered{}, m_pruned{};
694+
};
637695

638-
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
696+
if (!p_root) {
697+
return;
698+
}
699+
if (p_root->GetGame() != p_game) {
700+
throw MismatchException();
701+
}
639702

640-
struct Frame {
641-
GameNode m_node;
642-
ChildIterator m_current, m_end;
643-
};
703+
std::stack<Frame> stack;
704+
stack.push(Frame{p_root, {}, {}, 0, false, false});
644705

645-
Game m_owner{nullptr};
646-
TraversalOrder m_order{TraversalOrder::Preorder};
647-
std::stack<Frame> m_stack{};
648-
GameNode m_current{nullptr};
706+
while (!stack.empty()) {
707+
Frame &f = stack.top();
649708

650-
iterator(const Game &p_game, const GameNode &p_start, const TraversalOrder p_order)
651-
: m_owner(p_game), m_order(p_order)
652-
{
653-
if (!p_start) {
709+
if (!f.m_entered) {
710+
f.m_entered = true;
711+
const DFSCallbackResult result = p_callback.OnEnter(f.m_node, f.m_depth);
712+
if (result == DFSCallbackResult::Stop) {
654713
return;
655714
}
656-
if (p_start->GetGame() != p_game) {
657-
throw MismatchException();
715+
if (result == DFSCallbackResult::Prune) {
716+
f.m_pruned = true;
658717
}
659-
m_stack.push(make_frame(p_start));
660-
if (m_order == TraversalOrder::Preorder) {
661-
m_current = p_start;
718+
if (p_order == TraversalOrder::Preorder) {
719+
p_callback.OnVisit(f.m_node, f.m_depth);
662720
}
663-
else {
664-
descend_postorder();
665-
m_current = m_stack.empty() ? nullptr : m_stack.top().m_node;
666-
}
667-
if (!m_current) {
668-
m_owner = nullptr;
721+
if (!f.m_pruned && !f.m_node->IsTerminal()) {
722+
auto children = f.m_node->GetChildren();
723+
f.m_current = children.begin();
724+
f.m_end = children.end();
669725
}
670726
}
671727

672-
static Frame make_frame(const GameNode &p_node)
673-
{
674-
if (p_node->IsTerminal()) {
675-
return Frame{p_node, {}, {}};
728+
if (!f.m_pruned && !f.m_node->IsTerminal() && f.m_current != f.m_end) {
729+
GameNode const child = *f.m_current;
730+
++f.m_current;
731+
const DFSCallbackResult result = p_callback.OnAction(f.m_node, child, f.m_depth);
732+
733+
if (result == DFSCallbackResult::Stop) {
734+
return;
735+
}
736+
if (result == DFSCallbackResult::Prune) {
737+
continue;
676738
}
677-
const auto children = p_node->GetChildren();
678-
return Frame{p_node, children.begin(), children.end()};
739+
stack.push(Frame{child, {}, {}, f.m_depth + 1, false, false});
740+
continue;
679741
}
680742

681-
GameNode advance_preorder()
682-
{
683-
if (auto &[node, current, m_end] = m_stack.top(); !node->IsTerminal()) {
684-
const auto children = node->GetChildren();
685-
current = children.begin();
686-
m_end = children.end();
743+
const DFSCallbackResult result = p_callback.OnExit(f.m_node, f.m_depth);
744+
if (result == DFSCallbackResult::Stop) {
745+
return;
746+
}
747+
if (p_order == TraversalOrder::Postorder) {
748+
p_callback.OnVisit(f.m_node, f.m_depth);
749+
}
750+
stack.pop();
751+
}
752+
}
753+
754+
class Nodes {
755+
Game m_owner{nullptr};
756+
TraversalOrder m_order{TraversalOrder::Preorder};
757+
758+
public:
759+
class iterator {
760+
friend class Nodes;
761+
762+
struct NodeHandler {
763+
std::queue<GameNode> m_queue;
764+
765+
static DFSCallbackResult OnEnter(const GameNode &, int)
766+
{
767+
return DFSCallbackResult::Continue;
687768
}
688-
while (!m_stack.empty()) {
689-
if (auto &f = m_stack.top(); f.m_current != f.m_end) {
690-
GameNode const next = *f.m_current;
691-
++f.m_current;
692-
m_stack.push(make_frame(next));
693-
return next;
694-
}
695-
m_stack.pop();
769+
static DFSCallbackResult OnAction(const GameNode &, const GameNode &, int)
770+
{
771+
return DFSCallbackResult::Continue;
696772
}
697-
return nullptr;
698-
}
773+
static DFSCallbackResult OnExit(const GameNode &, int)
774+
{
775+
return DFSCallbackResult::Continue;
776+
}
777+
void OnVisit(const GameNode &p_node, int) { m_queue.push(p_node); }
778+
};
699779

700-
void descend_postorder()
780+
Game m_owner{nullptr};
781+
std::shared_ptr<NodeHandler> m_handler;
782+
GameNode m_current{nullptr};
783+
784+
iterator(const Game &p_game, TraversalOrder p_order)
785+
: m_owner(p_game), m_handler(std::make_shared<NodeHandler>())
701786
{
702-
while (!m_stack.empty()) {
703-
auto &f = m_stack.top();
704-
if (f.m_current == f.m_end) {
705-
return;
706-
}
707-
const auto child = *f.m_current;
708-
++f.m_current;
709-
m_stack.push(make_frame(child));
710-
if (!child->IsTerminal()) {
711-
continue;
712-
}
787+
if (!p_game) {
788+
m_owner = nullptr;
713789
return;
714790
}
791+
WalkDFS(p_game, p_game->GetRoot(), p_order, *m_handler);
792+
advance();
715793
}
716794

717-
GameNode advance_postorder()
795+
void advance()
718796
{
719-
m_stack.pop();
720-
if (m_stack.empty()) {
721-
return nullptr;
797+
if (!m_handler || m_handler->m_queue.empty()) {
798+
m_current = nullptr;
799+
m_owner = nullptr;
800+
}
801+
else {
802+
m_current = m_handler->m_queue.front();
803+
m_handler->m_queue.pop();
722804
}
723-
descend_postorder();
724-
return m_stack.empty() ? nullptr : m_stack.top().m_node;
725805
}
726806

727807
public:
@@ -732,8 +812,6 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
732812
using pointer = GameNode;
733813

734814
iterator() = default;
735-
iterator(const iterator &) = default;
736-
iterator &operator=(const iterator &) = default;
737815

738816
value_type operator*() const
739817
{
@@ -742,21 +820,11 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
742820
}
743821
return m_current;
744822
}
745-
746823
iterator &operator++()
747824
{
748-
if (!m_current) {
749-
return *this;
750-
}
751-
const auto next =
752-
(m_order == TraversalOrder::Preorder) ? advance_preorder() : advance_postorder();
753-
m_current = next;
754-
if (!m_current) {
755-
m_owner = nullptr;
756-
}
825+
advance();
757826
return *this;
758827
}
759-
760828
bool operator==(const iterator &p_other) const
761829
{
762830
return m_owner == p_other.m_owner && m_current == p_other.m_current;
@@ -770,10 +838,7 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
770838
{
771839
}
772840

773-
iterator begin() const
774-
{
775-
return (m_owner) ? iterator{m_owner, m_owner->GetRoot(), m_order} : iterator{};
776-
}
841+
iterator begin() const { return m_owner ? iterator{m_owner, m_order} : iterator{}; }
777842
static iterator end() { return iterator{}; }
778843
};
779844

0 commit comments

Comments
 (0)