Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 152 additions & 87 deletions src/games/game.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <list>
#include <set>
#include <stack>
#include <queue>
#include <memory>

#include "number.h"
Expand Down Expand Up @@ -627,101 +628,180 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
using Players = ElementCollection<Game, GamePlayerRep>;
using Outcomes = ElementCollection<Game, GameOutcomeRep>;

class Nodes {
Game m_owner{nullptr};
TraversalOrder m_order{TraversalOrder::Preorder};
enum class DFSCallbackResult {
Continue, // continue with normal traversal
Prune, // skip the subtree under this node
Stop // end traversal early now
};

public:
class iterator {
friend class Nodes;
/// @brief Perform a depth-first traversal of the game tree.
///
/// WalkDFS performs a depth-first traversal of the game tree starting at
/// the node @p p_root. Traversal order (preorder or postorder) controls when
/// the OnVisit() hook is called relative to the other callbacks.
///
/// @tparam Callback A callback type implementing the required interface (see below)).
/// @param p_game The game object being traversed
/// @param p_root The root node of the subtree to traverse.
/// @param p_order Controls when OnVisit() is invoked:
/// - TraversalOrder::Preorder: OnVisit() is called when a node is first entered.
/// - TraversalOrder::Postorder: OnVisit() is called after all children have been processed.
/// @param p_callback The callback object
///
/// The callback object must implement the following four member functions:
///
/// @code
/// DFSCallbackResult OnEnter(GameNode node, int depth);
/// DFSCallbackResult OnAction(GameNode parent, GameNode child, int depth);
/// DFSCallbackResult OnExit(GameNode node, int depth);
/// void OnVisit(GameNode node, int depth);
/// @endcode
///
/// OnEnter(node, depth)
/// --------------------
/// Invoked exactly once for each node, when the node is first reached during DFS.
/// @p depth is the depth of the node relative to @p p_root (where the root is defined
/// to have depth 0).
///
/// OnAction(parent, child, depth)
/// ------------------------------
/// Invoked immediately before descending from @p parent to one of its children.
/// @p depth is the depth of the parent node.
///
/// OnExit(node, depth)
/// -------------------
/// Invoked exactly once for each node, after all of its children have been
/// processed (or skipped). @p is the depth of the node.
///
/// OnVisit(node, depth)
/// --------------------
/// Invoked exactly once per node, at a time determined by @p p_order:
/// - Preorder: called during OnEnter(), before any children are visited.
/// - Postorder: called after OnExit(), once all children are complete.
///
template <class Callback>
static void WalkDFS(const Game &p_game, const GameNode &p_root, TraversalOrder p_order,
Callback &p_callback)
{
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;

struct Frame {
GameNode m_node;
ChildIterator m_current;
ChildIterator m_end;
int m_depth{};
bool m_entered{}, m_pruned{};
};

using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
if (!p_root) {
return;
}
if (p_root->GetGame() != p_game) {
throw MismatchException();
}

struct Frame {
GameNode m_node;
ChildIterator m_current, m_end;
};
std::stack<Frame> stack;
stack.push(Frame{p_root, {}, {}, 0, false, false});

Game m_owner{nullptr};
TraversalOrder m_order{TraversalOrder::Preorder};
std::stack<Frame> m_stack{};
GameNode m_current{nullptr};
while (!stack.empty()) {
Frame &f = stack.top();

iterator(const Game &p_game, const GameNode &p_start, const TraversalOrder p_order)
: m_owner(p_game), m_order(p_order)
{
if (!p_start) {
if (!f.m_entered) {
f.m_entered = true;
const DFSCallbackResult result = p_callback.OnEnter(f.m_node, f.m_depth);
if (result == DFSCallbackResult::Stop) {
return;
}
if (p_start->GetGame() != p_game) {
throw MismatchException();
if (result == DFSCallbackResult::Prune) {
f.m_pruned = true;
}
m_stack.push(make_frame(p_start));
if (m_order == TraversalOrder::Preorder) {
m_current = p_start;
if (p_order == TraversalOrder::Preorder) {
p_callback.OnVisit(f.m_node, f.m_depth);
}
else {
descend_postorder();
m_current = m_stack.empty() ? nullptr : m_stack.top().m_node;
}
if (!m_current) {
m_owner = nullptr;
if (!f.m_pruned && !f.m_node->IsTerminal()) {
auto children = f.m_node->GetChildren();
f.m_current = children.begin();
f.m_end = children.end();
}
}

static Frame make_frame(const GameNode &p_node)
{
if (p_node->IsTerminal()) {
return Frame{p_node, {}, {}};
if (!f.m_pruned && !f.m_node->IsTerminal() && f.m_current != f.m_end) {
GameNode const child = *f.m_current;
++f.m_current;
const DFSCallbackResult result = p_callback.OnAction(f.m_node, child, f.m_depth);

if (result == DFSCallbackResult::Stop) {
return;
}
if (result == DFSCallbackResult::Prune) {
continue;
}
const auto children = p_node->GetChildren();
return Frame{p_node, children.begin(), children.end()};
stack.push(Frame{child, {}, {}, f.m_depth + 1, false, false});
continue;
}

GameNode advance_preorder()
{
if (auto &[node, current, m_end] = m_stack.top(); !node->IsTerminal()) {
const auto children = node->GetChildren();
current = children.begin();
m_end = children.end();
const DFSCallbackResult result = p_callback.OnExit(f.m_node, f.m_depth);
if (result == DFSCallbackResult::Stop) {
return;
}
if (p_order == TraversalOrder::Postorder) {
p_callback.OnVisit(f.m_node, f.m_depth);
}
stack.pop();
}
}

class Nodes {
Game m_owner{nullptr};
TraversalOrder m_order{TraversalOrder::Preorder};

public:
class iterator {
friend class Nodes;

struct NodeHandler {
std::queue<GameNode> m_queue;

static DFSCallbackResult OnEnter(const GameNode &, int)
{
return DFSCallbackResult::Continue;
}
while (!m_stack.empty()) {
if (auto &f = m_stack.top(); f.m_current != f.m_end) {
GameNode const next = *f.m_current;
++f.m_current;
m_stack.push(make_frame(next));
return next;
}
m_stack.pop();
static DFSCallbackResult OnAction(const GameNode &, const GameNode &, int)
{
return DFSCallbackResult::Continue;
}
return nullptr;
}
static DFSCallbackResult OnExit(const GameNode &, int)
{
return DFSCallbackResult::Continue;
}
void OnVisit(const GameNode &p_node, int) { m_queue.push(p_node); }
};

void descend_postorder()
Game m_owner{nullptr};
std::shared_ptr<NodeHandler> m_handler;
GameNode m_current{nullptr};

iterator(const Game &p_game, TraversalOrder p_order)
: m_owner(p_game), m_handler(std::make_shared<NodeHandler>())
{
while (!m_stack.empty()) {
auto &f = m_stack.top();
if (f.m_current == f.m_end) {
return;
}
const auto child = *f.m_current;
++f.m_current;
m_stack.push(make_frame(child));
if (!child->IsTerminal()) {
continue;
}
if (!p_game) {
m_owner = nullptr;
return;
}
WalkDFS(p_game, p_game->GetRoot(), p_order, *m_handler);
advance();
}

GameNode advance_postorder()
void advance()
{
m_stack.pop();
if (m_stack.empty()) {
return nullptr;
if (!m_handler || m_handler->m_queue.empty()) {
m_current = nullptr;
m_owner = nullptr;
}
else {
m_current = m_handler->m_queue.front();
m_handler->m_queue.pop();
}
descend_postorder();
return m_stack.empty() ? nullptr : m_stack.top().m_node;
}

public:
Expand All @@ -732,8 +812,6 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
using pointer = GameNode;

iterator() = default;
iterator(const iterator &) = default;
iterator &operator=(const iterator &) = default;

value_type operator*() const
{
Expand All @@ -742,21 +820,11 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
}
return m_current;
}

iterator &operator++()
{
if (!m_current) {
return *this;
}
const auto next =
(m_order == TraversalOrder::Preorder) ? advance_preorder() : advance_postorder();
m_current = next;
if (!m_current) {
m_owner = nullptr;
}
advance();
return *this;
}

bool operator==(const iterator &p_other) const
{
return m_owner == p_other.m_owner && m_current == p_other.m_current;
Expand All @@ -770,10 +838,7 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
{
}

iterator begin() const
{
return (m_owner) ? iterator{m_owner, m_owner->GetRoot(), m_order} : iterator{};
}
iterator begin() const { return m_owner ? iterator{m_owner, m_order} : iterator{}; }
static iterator end() { return iterator{}; }
};

Expand Down
Loading