Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
prescribed by a strategy at an information set
- Tests for creation of the reduced strategic form from an extensive-form game (currently only
for games with perfect recall)
- Implement a `Nodes` class as a member of `GameRep`. This class includes an iterator
that starts from the root node and successively returns nodes in depth-first traversal order (#530)


## [16.3.1] - unreleased
Expand Down
95 changes: 95 additions & 0 deletions src/games/game.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <list>
#include <set>
#include <stack>

#include "number.h"
#include "gameobject.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ template <class P, class T> class ElementCollection {

public:
class iterator {

P m_owner{nullptr};
const std::vector<T *> *m_container{nullptr};
size_t m_index{0};
Expand Down Expand Up @@ -103,6 +105,7 @@ template <class P, class T> class ElementCollection {
return *this;
}
value_type operator*() const { return m_container->at(m_index); }
P GetOwner() const { return m_owner; }
};

ElementCollection() = default;
Expand All @@ -118,6 +121,8 @@ template <class P, class T> class ElementCollection {
{
return m_owner == p_other.m_owner && m_container == p_other.m_container;
}

bool empty() const { return size() == 0; }
size_t size() const { return m_container->size(); }
GameObjectPtr<T> front() const { return m_container->front(); }
GameObjectPtr<T> back() const { return m_container->back(); }
Expand Down Expand Up @@ -512,6 +517,94 @@ class GameRep : public BaseGameRep {
using Players = ElementCollection<Game, GamePlayerRep>;
using Outcomes = ElementCollection<Game, GameOutcomeRep>;

class Nodes {
Game m_owner{nullptr};

public:
class iterator {
friend class Nodes;
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;

Game m_owner{nullptr};
GameNode m_current_node{nullptr};
std::stack<std::pair<ChildIterator, ChildIterator>> m_stack{};

iterator(Game game) : m_owner(game) {}

public:
using iterator_category = std::forward_iterator_tag;
using value_type = GameNode;
using pointer = value_type *;

iterator() = default;

iterator(Game game, GameNode start_node) : m_owner(game), m_current_node(start_node)
{
if (!start_node) {
return;
}
if (start_node->GetGame() != m_owner) {
throw MismatchException();
}
}

value_type operator*() const
{
if (!m_current_node) {
throw std::runtime_error("Cannot dereference an end iterator");
}
return m_current_node;
}

iterator &operator++()
{
if (!m_current_node) {
throw std::out_of_range("Cannot increment an end iterator");
}

if (!m_current_node->IsTerminal()) {
auto children = m_current_node->GetChildren();
m_stack.emplace(children.begin(), children.end());
}

while (!m_stack.empty()) {
auto &[current_it, end_it] = m_stack.top();

if (current_it != end_it) {
m_current_node = *current_it;
++current_it;
return *this;
}
m_stack.pop();
}

m_current_node = nullptr;
return *this;
}

bool operator==(const iterator &other) const
{
return m_owner == other.m_owner && m_current_node == other.m_current_node;
}
bool operator!=(const iterator &other) const { return !(*this == other); }
};

/// Default constructor to support declaration by other modules (e.g. Cython)
Nodes() = default;

/// Constructor for the Nodes range.
explicit Nodes(Game p_owner) : m_owner(p_owner) {}

/// Returns an iterator to the first node (the root).
iterator begin() const
{
return (m_owner) ? iterator{m_owner, m_owner->GetRoot()} : iterator{};
}

/// Returns an iterator to the past-the-end position.
iterator end() const { return (m_owner) ? iterator{m_owner} : iterator{}; }
};

/// @name Lifecycle
//@{
/// Clean up the game
Expand Down Expand Up @@ -724,6 +817,8 @@ class GameRep : public BaseGameRep {
//@{
/// Returns the root node of the game
virtual GameNode GetRoot() const = 0;
/// Returns a range that can be used to iterate over the nodes of the game
Nodes GetNodes() const { return Nodes(this); }
/// Returns the number of nodes in the game
virtual size_t NumNodes() const = 0;
/// Returns the number of non-terminal nodes in the game
Expand Down
10 changes: 10 additions & 0 deletions src/pygambit/gambit.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ cdef extern from "games/game.h":
iterator begin() except +
iterator end() except +

cppclass Nodes:
cppclass iterator:
bint operator ==(iterator)
bint operator !=(iterator)
c_GameNode operator *()
iterator operator++()
iterator begin() except +
iterator end() except +

int IsTree() except +

string GetTitle() except +
Expand All @@ -264,6 +273,7 @@ cdef extern from "games/game.h":
int NumNodes() except +
int NumNonterminalNodes() except +
c_GameNode GetRoot() except +
Nodes GetNodes() except +

c_GameStrategy GetStrategy(int) except +IndexError
c_GameStrategy NewStrategy(c_GamePlayer, string) except +
Expand Down
9 changes: 4 additions & 5 deletions src/pygambit/game.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,12 @@ class GameNodes:
return self.game.deref().NumNodes()

def __iter__(self) -> typing.Iterator[Node]:
def dfs(node):
yield node
for child in node.children:
yield from dfs(child)
"""Iterate over the game nodes in the depth-first traversal order."""
if not self.game.deref().IsTree():
return
yield from dfs(Node.wrap(self.game.deref().GetRoot()))

for node in self.game.deref().GetNodes():
yield Node.wrap(node)


@cython.cclass
Expand Down
24 changes: 24 additions & 0 deletions tests/test_node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import typing

import pytest
Expand Down Expand Up @@ -787,3 +788,26 @@ def test_node_plays():
} # paths=[0, 1], [0, 1, 1], [1, 1, 1]

assert set(test_node.plays) == expected_set_of_plays


@pytest.mark.parametrize(
"game_obj",
[
pytest.param(games.read_from_file("basic_extensive_game.efg")),
pytest.param(games.read_from_file("binary_3_levels_generic_payoffs.efg")),
pytest.param(games.read_from_file("cent3.efg")),
pytest.param(games.read_from_file("e01.efg")),
pytest.param(games.read_from_file("e02.efg")),
pytest.param(games.read_from_file("poker.efg")),
pytest.param(gbt.Game.new_tree()),
],
)
def test_nodes_iteration_order(game_obj: gbt.Game):
"""Verify that the C++ `game.nodes` iterator produces the DFS traversal.
"""
def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]:
yield node
for child in node.children:
yield from dfs(child)

assert all(a == b for a, b in itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)))
Loading