From 0b370da53dcc8325e1967200c67723b3f3ddbe40 Mon Sep 17 00:00:00 2001 From: drdkad Date: Sat, 21 Jun 2025 12:08:55 +0100 Subject: [PATCH 1/2] Implement Node class, allowing for iteration over game nodes in the DFT order. Change Python API to use the new nodes iterator. Add tests. --- ChangeLog | 2 + src/games/game.h | 92 +++++++++++++++++++++++++++++++++++++++++ src/pygambit/gambit.pxd | 10 +++++ src/pygambit/game.pxi | 9 ++-- tests/test_node.py | 28 +++++++++++++ 5 files changed, 136 insertions(+), 5 deletions(-) diff --git a/ChangeLog b/ChangeLog index c9ab9f6ac..fd83e3cc5 100644 --- a/ChangeLog +++ b/ChangeLog @@ -16,6 +16,8 @@ prescribed by a strategy at an information set - Tests for creation of the reduced strategic form from an extensive-form game (currently only for games with perfect recall) +- Implement a `Nodes` class as a member of `GameRep`. This class includes an iterator + that starts from the root node and successively returns nodes in depth-first traversal order (#530) ## [16.3.1] - unreleased diff --git a/src/games/game.h b/src/games/game.h index 7dd1d680c..ec23f1db8 100644 --- a/src/games/game.h +++ b/src/games/game.h @@ -25,6 +25,7 @@ #include #include +#include #include "number.h" #include "gameobject.h" @@ -61,6 +62,8 @@ template class ElementCollection { public: class iterator { + friend class GameRep; + P m_owner{nullptr}; const std::vector *m_container{nullptr}; size_t m_index{0}; @@ -512,6 +515,93 @@ class GameRep : public BaseGameRep { using Players = ElementCollection; using Outcomes = ElementCollection; + class Nodes { + Game m_owner{nullptr}; + + public: + class iterator { + friend class Nodes; + using ChildIterator = ElementCollection::iterator; + + Game m_owner{nullptr}; + GameNode m_current_node{nullptr}; + std::stack m_stack{}; + + iterator(Game game) : m_owner(game) {} + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = GameNode; + using pointer = value_type *; + + iterator() = default; + + iterator(Game game, GameNode start_node) : m_owner(game), m_current_node(start_node) + { + if (!start_node) { + return; + } + if (start_node->GetGame() != m_owner) { + throw MismatchException(); + } + } + + value_type operator*() const + { + if (!m_current_node) { + throw std::runtime_error("Cannot dereference an end iterator"); + } + return m_current_node; + } + + iterator &operator++() + { + if (!m_current_node) { + throw std::out_of_range("Cannot increment an end iterator"); + } + + if (!m_current_node->IsTerminal()) { + m_stack.push(m_current_node->GetChildren().begin()); + } + + while (!m_stack.empty()) { + auto &top_it = m_stack.top(); + auto end_it = top_it.m_owner->GetChildren().end(); + if (top_it != end_it) { + m_current_node = *top_it; + ++top_it; + return *this; + } + m_stack.pop(); + } + + m_current_node = nullptr; + return *this; + } + + bool operator==(const iterator &other) const + { + return m_owner == other.m_owner && m_current_node == other.m_current_node; + } + bool operator!=(const iterator &other) const { return !(*this == other); } + }; + + /// Default constructor to support declaration by other modules (e.g. Cython) + Nodes() = default; + + /// Constructor for the Nodes range. + explicit Nodes(Game p_owner) : m_owner(p_owner) {} + + /// Returns an iterator to the first node (the root). + iterator begin() const + { + return (m_owner) ? iterator{m_owner, m_owner->GetRoot()} : iterator{}; + } + + /// Returns an iterator to the past-the-end position. + iterator end() const { return (m_owner) ? iterator{m_owner} : iterator{}; } + }; + /// @name Lifecycle //@{ /// Clean up the game @@ -724,6 +814,8 @@ class GameRep : public BaseGameRep { //@{ /// Returns the root node of the game virtual GameNode GetRoot() const = 0; + /// Returns a range that can be used to iterate over the nodes of the game + Nodes GetNodes() const { return Nodes(this); } /// Returns the number of nodes in the game virtual size_t NumNodes() const = 0; /// Returns the number of non-terminal nodes in the game diff --git a/src/pygambit/gambit.pxd b/src/pygambit/gambit.pxd index 55e573680..affe12f0f 100644 --- a/src/pygambit/gambit.pxd +++ b/src/pygambit/gambit.pxd @@ -241,6 +241,15 @@ cdef extern from "games/game.h": iterator begin() except + iterator end() except + + cppclass Nodes: + cppclass iterator: + bint operator ==(iterator) + bint operator !=(iterator) + c_GameNode operator *() + iterator operator++() + iterator begin() except + + iterator end() except + + int IsTree() except + string GetTitle() except + @@ -264,6 +273,7 @@ cdef extern from "games/game.h": int NumNodes() except + int NumNonterminalNodes() except + c_GameNode GetRoot() except + + Nodes GetNodes() except + c_GameStrategy GetStrategy(int) except +IndexError c_GameStrategy NewStrategy(c_GamePlayer, string) except + diff --git a/src/pygambit/game.pxi b/src/pygambit/game.pxi index a7ddaab2d..529ca0d7b 100644 --- a/src/pygambit/game.pxi +++ b/src/pygambit/game.pxi @@ -184,13 +184,12 @@ class GameNodes: return self.game.deref().NumNodes() def __iter__(self) -> typing.Iterator[Node]: - def dfs(node): - yield node - for child in node.children: - yield from dfs(child) + """Iterate over the game nodes in the depth-first traversal order.""" if not self.game.deref().IsTree(): return - yield from dfs(Node.wrap(self.game.deref().GetRoot())) + + for node in self.game.deref().GetNodes(): + yield Node.wrap(node) @cython.cclass diff --git a/tests/test_node.py b/tests/test_node.py index c61b56745..771592fd6 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -1,3 +1,4 @@ +import itertools import typing import pytest @@ -787,3 +788,30 @@ def test_node_plays(): } # paths=[0, 1], [0, 1, 1], [1, 1, 1] assert set(test_node.plays) == expected_set_of_plays + + +@pytest.mark.parametrize( + "game_obj", + [ + pytest.param(games.read_from_file("basic_extensive_game.efg")), + pytest.param(games.read_from_file("binary_3_levels_generic_payoffs.efg")), + pytest.param(games.read_from_file("cent3.efg")), + pytest.param(games.read_from_file("e01.efg")), + pytest.param(games.read_from_file("e02.efg")), + pytest.param(games.read_from_file("poker.efg")), + pytest.param(gbt.Game.new_tree()), + ], +) +def test_nodes_iteration_order(game_obj: gbt.Game): + """ + Verify that the C++ `game.nodes` iterator produces the DFS traversal. + + """ + def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]: + yield node + for child in node.children: + yield from dfs(child) + + zipped_nodes = itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)) + + assert all(a == b for a, b in zipped_nodes) From dcf29a5bab33c3872039717a63058aca09a1b8a7 Mon Sep 17 00:00:00 2001 From: drdkad Date: Wed, 9 Jul 2025 06:49:24 +0100 Subject: [PATCH 2/2] Improve tests, add two methods to ElementCollection --- src/games/game.h | 19 +++++++++++-------- tests/test_node.py | 8 ++------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/games/game.h b/src/games/game.h index ec23f1db8..5c4689f52 100644 --- a/src/games/game.h +++ b/src/games/game.h @@ -62,7 +62,6 @@ template class ElementCollection { public: class iterator { - friend class GameRep; P m_owner{nullptr}; const std::vector *m_container{nullptr}; @@ -106,6 +105,7 @@ template class ElementCollection { return *this; } value_type operator*() const { return m_container->at(m_index); } + P GetOwner() const { return m_owner; } }; ElementCollection() = default; @@ -121,6 +121,8 @@ template class ElementCollection { { return m_owner == p_other.m_owner && m_container == p_other.m_container; } + + bool empty() const { return size() == 0; } size_t size() const { return m_container->size(); } GameObjectPtr front() const { return m_container->front(); } GameObjectPtr back() const { return m_container->back(); } @@ -525,7 +527,7 @@ class GameRep : public BaseGameRep { Game m_owner{nullptr}; GameNode m_current_node{nullptr}; - std::stack m_stack{}; + std::stack> m_stack{}; iterator(Game game) : m_owner(game) {} @@ -561,15 +563,16 @@ class GameRep : public BaseGameRep { } if (!m_current_node->IsTerminal()) { - m_stack.push(m_current_node->GetChildren().begin()); + auto children = m_current_node->GetChildren(); + m_stack.emplace(children.begin(), children.end()); } while (!m_stack.empty()) { - auto &top_it = m_stack.top(); - auto end_it = top_it.m_owner->GetChildren().end(); - if (top_it != end_it) { - m_current_node = *top_it; - ++top_it; + auto &[current_it, end_it] = m_stack.top(); + + if (current_it != end_it) { + m_current_node = *current_it; + ++current_it; return *this; } m_stack.pop(); diff --git a/tests/test_node.py b/tests/test_node.py index 771592fd6..4104e0e64 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -803,15 +803,11 @@ def test_node_plays(): ], ) def test_nodes_iteration_order(game_obj: gbt.Game): - """ - Verify that the C++ `game.nodes` iterator produces the DFS traversal. - + """Verify that the C++ `game.nodes` iterator produces the DFS traversal. """ def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]: yield node for child in node.children: yield from dfs(child) - zipped_nodes = itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)) - - assert all(a == b for a, b in zipped_nodes) + assert all(a == b for a, b in itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)))