Skip to content

Commit 0b370da

Browse files
committed
Implement Node class, allowing for iteration over game nodes in the DFT order. Change Python API to use the new nodes iterator. Add tests.
1 parent b017e1a commit 0b370da

5 files changed

Lines changed: 136 additions & 5 deletions

File tree

ChangeLog

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
prescribed by a strategy at an information set
1717
- Tests for creation of the reduced strategic form from an extensive-form game (currently only
1818
for games with perfect recall)
19+
- Implement a `Nodes` class as a member of `GameRep`. This class includes an iterator
20+
that starts from the root node and successively returns nodes in depth-first traversal order (#530)
1921

2022

2123
## [16.3.1] - unreleased

src/games/game.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <list>
2727
#include <set>
28+
#include <stack>
2829

2930
#include "number.h"
3031
#include "gameobject.h"
@@ -61,6 +62,8 @@ template <class P, class T> class ElementCollection {
6162

6263
public:
6364
class iterator {
65+
friend class GameRep;
66+
6467
P m_owner{nullptr};
6568
const std::vector<T *> *m_container{nullptr};
6669
size_t m_index{0};
@@ -512,6 +515,93 @@ class GameRep : public BaseGameRep {
512515
using Players = ElementCollection<Game, GamePlayerRep>;
513516
using Outcomes = ElementCollection<Game, GameOutcomeRep>;
514517

518+
class Nodes {
519+
Game m_owner{nullptr};
520+
521+
public:
522+
class iterator {
523+
friend class Nodes;
524+
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
525+
526+
Game m_owner{nullptr};
527+
GameNode m_current_node{nullptr};
528+
std::stack<ChildIterator> m_stack{};
529+
530+
iterator(Game game) : m_owner(game) {}
531+
532+
public:
533+
using iterator_category = std::forward_iterator_tag;
534+
using value_type = GameNode;
535+
using pointer = value_type *;
536+
537+
iterator() = default;
538+
539+
iterator(Game game, GameNode start_node) : m_owner(game), m_current_node(start_node)
540+
{
541+
if (!start_node) {
542+
return;
543+
}
544+
if (start_node->GetGame() != m_owner) {
545+
throw MismatchException();
546+
}
547+
}
548+
549+
value_type operator*() const
550+
{
551+
if (!m_current_node) {
552+
throw std::runtime_error("Cannot dereference an end iterator");
553+
}
554+
return m_current_node;
555+
}
556+
557+
iterator &operator++()
558+
{
559+
if (!m_current_node) {
560+
throw std::out_of_range("Cannot increment an end iterator");
561+
}
562+
563+
if (!m_current_node->IsTerminal()) {
564+
m_stack.push(m_current_node->GetChildren().begin());
565+
}
566+
567+
while (!m_stack.empty()) {
568+
auto &top_it = m_stack.top();
569+
auto end_it = top_it.m_owner->GetChildren().end();
570+
if (top_it != end_it) {
571+
m_current_node = *top_it;
572+
++top_it;
573+
return *this;
574+
}
575+
m_stack.pop();
576+
}
577+
578+
m_current_node = nullptr;
579+
return *this;
580+
}
581+
582+
bool operator==(const iterator &other) const
583+
{
584+
return m_owner == other.m_owner && m_current_node == other.m_current_node;
585+
}
586+
bool operator!=(const iterator &other) const { return !(*this == other); }
587+
};
588+
589+
/// Default constructor to support declaration by other modules (e.g. Cython)
590+
Nodes() = default;
591+
592+
/// Constructor for the Nodes range.
593+
explicit Nodes(Game p_owner) : m_owner(p_owner) {}
594+
595+
/// Returns an iterator to the first node (the root).
596+
iterator begin() const
597+
{
598+
return (m_owner) ? iterator{m_owner, m_owner->GetRoot()} : iterator{};
599+
}
600+
601+
/// Returns an iterator to the past-the-end position.
602+
iterator end() const { return (m_owner) ? iterator{m_owner} : iterator{}; }
603+
};
604+
515605
/// @name Lifecycle
516606
//@{
517607
/// Clean up the game
@@ -724,6 +814,8 @@ class GameRep : public BaseGameRep {
724814
//@{
725815
/// Returns the root node of the game
726816
virtual GameNode GetRoot() const = 0;
817+
/// Returns a range that can be used to iterate over the nodes of the game
818+
Nodes GetNodes() const { return Nodes(this); }
727819
/// Returns the number of nodes in the game
728820
virtual size_t NumNodes() const = 0;
729821
/// Returns the number of non-terminal nodes in the game

src/pygambit/gambit.pxd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ cdef extern from "games/game.h":
241241
iterator begin() except +
242242
iterator end() except +
243243

244+
cppclass Nodes:
245+
cppclass iterator:
246+
bint operator ==(iterator)
247+
bint operator !=(iterator)
248+
c_GameNode operator *()
249+
iterator operator++()
250+
iterator begin() except +
251+
iterator end() except +
252+
244253
int IsTree() except +
245254

246255
string GetTitle() except +
@@ -264,6 +273,7 @@ cdef extern from "games/game.h":
264273
int NumNodes() except +
265274
int NumNonterminalNodes() except +
266275
c_GameNode GetRoot() except +
276+
Nodes GetNodes() except +
267277

268278
c_GameStrategy GetStrategy(int) except +IndexError
269279
c_GameStrategy NewStrategy(c_GamePlayer, string) except +

src/pygambit/game.pxi

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,12 @@ class GameNodes:
184184
return self.game.deref().NumNodes()
185185

186186
def __iter__(self) -> typing.Iterator[Node]:
187-
def dfs(node):
188-
yield node
189-
for child in node.children:
190-
yield from dfs(child)
187+
"""Iterate over the game nodes in the depth-first traversal order."""
191188
if not self.game.deref().IsTree():
192189
return
193-
yield from dfs(Node.wrap(self.game.deref().GetRoot()))
190+
191+
for node in self.game.deref().GetNodes():
192+
yield Node.wrap(node)
194193

195194

196195
@cython.cclass

tests/test_node.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import typing
23

34
import pytest
@@ -787,3 +788,30 @@ def test_node_plays():
787788
} # paths=[0, 1], [0, 1, 1], [1, 1, 1]
788789

789790
assert set(test_node.plays) == expected_set_of_plays
791+
792+
793+
@pytest.mark.parametrize(
794+
"game_obj",
795+
[
796+
pytest.param(games.read_from_file("basic_extensive_game.efg")),
797+
pytest.param(games.read_from_file("binary_3_levels_generic_payoffs.efg")),
798+
pytest.param(games.read_from_file("cent3.efg")),
799+
pytest.param(games.read_from_file("e01.efg")),
800+
pytest.param(games.read_from_file("e02.efg")),
801+
pytest.param(games.read_from_file("poker.efg")),
802+
pytest.param(gbt.Game.new_tree()),
803+
],
804+
)
805+
def test_nodes_iteration_order(game_obj: gbt.Game):
806+
"""
807+
Verify that the C++ `game.nodes` iterator produces the DFS traversal.
808+
809+
"""
810+
def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]:
811+
yield node
812+
for child in node.children:
813+
yield from dfs(child)
814+
815+
zipped_nodes = itertools.zip_longest(game_obj.nodes, dfs(game_obj.root))
816+
817+
assert all(a == b for a, b in zipped_nodes)

0 commit comments

Comments
 (0)