Skip to content

Commit 89cbddc

Browse files
d-kadtturocy
authored andcommitted
Implement Nodes collection for games.
This implements a Nodes collection class in a game, representing the set of all nodes in the game tree (if defined). The default iteration of the set of nodes is defined as depth-first traversal. Closes #530.
1 parent b017e1a commit 89cbddc

5 files changed

Lines changed: 127 additions & 5 deletions

File tree

ChangeLog

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
prescribed by a strategy at an information set
1717
- Tests for creation of the reduced strategic form from an extensive-form game (currently only
1818
for games with perfect recall)
19+
- Implement `Nodes` collection as a member of `GameRep`, including a C++ iterator that
20+
returns nodes in depth-first traversal order (#530)
1921

2022

2123
## [16.3.1] - unreleased

src/games/game.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <list>
2727
#include <set>
28+
#include <stack>
2829

2930
#include "number.h"
3031
#include "gameobject.h"
@@ -118,6 +119,8 @@ template <class P, class T> class ElementCollection {
118119
{
119120
return m_owner == p_other.m_owner && m_container == p_other.m_container;
120121
}
122+
123+
bool empty() const { return m_container->empty(); }
121124
size_t size() const { return m_container->size(); }
122125
GameObjectPtr<T> front() const { return m_container->front(); }
123126
GameObjectPtr<T> back() const { return m_container->back(); }
@@ -512,6 +515,88 @@ class GameRep : public BaseGameRep {
512515
using Players = ElementCollection<Game, GamePlayerRep>;
513516
using Outcomes = ElementCollection<Game, GameOutcomeRep>;
514517

518+
class Nodes {
519+
Game m_owner{nullptr};
520+
521+
public:
522+
class iterator {
523+
friend class Nodes;
524+
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
525+
526+
Game m_owner{nullptr};
527+
GameNode m_current_node{nullptr};
528+
std::stack<std::pair<ChildIterator, ChildIterator>> m_stack{};
529+
530+
iterator(const Game &game) : m_owner(game) {}
531+
532+
public:
533+
using iterator_category = std::forward_iterator_tag;
534+
using value_type = GameNode;
535+
using pointer = value_type *;
536+
537+
iterator() = default;
538+
539+
iterator(const Game &game, const GameNode &start_node) : m_owner(game), m_current_node(start_node)
540+
{
541+
if (!start_node) {
542+
return;
543+
}
544+
if (start_node->GetGame() != m_owner) {
545+
throw MismatchException();
546+
}
547+
}
548+
549+
value_type operator*() const
550+
{
551+
if (!m_current_node) {
552+
throw std::runtime_error("Cannot dereference an end iterator");
553+
}
554+
return m_current_node;
555+
}
556+
557+
iterator &operator++()
558+
{
559+
if (!m_current_node) {
560+
throw std::out_of_range("Cannot increment an end iterator");
561+
}
562+
563+
if (!m_current_node->IsTerminal()) {
564+
auto children = m_current_node->GetChildren();
565+
m_stack.emplace(children.begin(), children.end());
566+
}
567+
568+
while (!m_stack.empty()) {
569+
auto &[current_it, end_it] = m_stack.top();
570+
571+
if (current_it != end_it) {
572+
m_current_node = *current_it;
573+
++current_it;
574+
return *this;
575+
}
576+
m_stack.pop();
577+
}
578+
579+
m_current_node = nullptr;
580+
return *this;
581+
}
582+
583+
bool operator==(const iterator &other) const
584+
{
585+
return m_owner == other.m_owner && m_current_node == other.m_current_node;
586+
}
587+
bool operator!=(const iterator &other) const { return !(*this == other); }
588+
};
589+
590+
Nodes() = default;
591+
explicit Nodes(const Game &p_owner) : m_owner(p_owner) {}
592+
593+
iterator begin() const
594+
{
595+
return (m_owner) ? iterator{m_owner, m_owner->GetRoot()} : iterator{};
596+
}
597+
iterator end() const { return (m_owner) ? iterator{m_owner} : iterator{}; }
598+
};
599+
515600
/// @name Lifecycle
516601
//@{
517602
/// Clean up the game
@@ -724,6 +809,8 @@ class GameRep : public BaseGameRep {
724809
//@{
725810
/// Returns the root node of the game
726811
virtual GameNode GetRoot() const = 0;
812+
/// Returns a range that can be used to iterate over the nodes of the game
813+
Nodes GetNodes() const { return Nodes(this); }
727814
/// Returns the number of nodes in the game
728815
virtual size_t NumNodes() const = 0;
729816
/// Returns the number of non-terminal nodes in the game

src/pygambit/gambit.pxd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ cdef extern from "games/game.h":
241241
iterator begin() except +
242242
iterator end() except +
243243

244+
cppclass Nodes:
245+
cppclass iterator:
246+
bint operator ==(iterator)
247+
bint operator !=(iterator)
248+
c_GameNode operator *()
249+
iterator operator++()
250+
iterator begin() except +
251+
iterator end() except +
252+
244253
int IsTree() except +
245254

246255
string GetTitle() except +
@@ -264,6 +273,7 @@ cdef extern from "games/game.h":
264273
int NumNodes() except +
265274
int NumNonterminalNodes() except +
266275
c_GameNode GetRoot() except +
276+
Nodes GetNodes() except +
267277

268278
c_GameStrategy GetStrategy(int) except +IndexError
269279
c_GameStrategy NewStrategy(c_GamePlayer, string) except +

src/pygambit/game.pxi

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,12 @@ class GameNodes:
184184
return self.game.deref().NumNodes()
185185

186186
def __iter__(self) -> typing.Iterator[Node]:
187-
def dfs(node):
188-
yield node
189-
for child in node.children:
190-
yield from dfs(child)
187+
"""Iterate over the game nodes in the depth-first traversal order."""
191188
if not self.game.deref().IsTree():
192189
return
193-
yield from dfs(Node.wrap(self.game.deref().GetRoot()))
190+
191+
for node in self.game.deref().GetNodes():
192+
yield Node.wrap(node)
194193

195194

196195
@cython.cclass

tests/test_node.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import typing
23

34
import pytest
@@ -787,3 +788,26 @@ def test_node_plays():
787788
} # paths=[0, 1], [0, 1, 1], [1, 1, 1]
788789

789790
assert set(test_node.plays) == expected_set_of_plays
791+
792+
793+
@pytest.mark.parametrize(
794+
"game_obj",
795+
[
796+
pytest.param(games.read_from_file("basic_extensive_game.efg")),
797+
pytest.param(games.read_from_file("binary_3_levels_generic_payoffs.efg")),
798+
pytest.param(games.read_from_file("cent3.efg")),
799+
pytest.param(games.read_from_file("e01.efg")),
800+
pytest.param(games.read_from_file("e02.efg")),
801+
pytest.param(games.read_from_file("poker.efg")),
802+
pytest.param(gbt.Game.new_tree()),
803+
],
804+
)
805+
def test_nodes_iteration_order(game_obj: gbt.Game):
806+
"""Verify that the C++ `game.nodes` iterator produces the DFS traversal.
807+
"""
808+
def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]:
809+
yield node
810+
for child in node.children:
811+
yield from dfs(child)
812+
813+
assert all(a == b for a, b in itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)))

0 commit comments

Comments
 (0)