Skip to content

Commit 803141c

Browse files
committed
Refactor Nodes class; WIP: dereferencing the iterator in opetor++
1 parent f2aace7 commit 803141c

3 files changed

Lines changed: 87 additions & 74 deletions

File tree

src/games/game.h

Lines changed: 76 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -514,73 +514,94 @@ class GameRep : public BaseGameRep {
514514
using Outcomes = ElementCollection<Game, GameOutcomeRep>;
515515

516516
class Nodes {
517-
private:
518-
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
519-
520517
Game m_owner{nullptr};
521-
GameNode m_current_node{nullptr};
522-
std::stack<std::pair<ChildIterator, ChildIterator>> m_stack{};
523-
524-
Nodes(Game game) : m_owner(game) {}
525518

526519
public:
527-
Nodes() = default;
528-
529-
Nodes(Game game, GameNode start_node) : m_owner(game), m_current_node(start_node)
530-
{
531-
if (!start_node) {
532-
return;
533-
}
534-
if (start_node->GetGame() != m_owner) {
535-
throw MismatchException();
520+
class iterator {
521+
friend class Nodes;
522+
using ChildIterator = ElementCollection<GameNode, GameNodeRep>::iterator;
523+
524+
Game m_owner{nullptr};
525+
GameNode m_current_node{nullptr};
526+
std::stack<ChildIterator> m_stack{};
527+
528+
iterator(Game game) : m_owner(game) {}
529+
530+
public:
531+
using iterator_category = std::forward_iterator_tag;
532+
using value_type = GameNode;
533+
using pointer = value_type *;
534+
535+
iterator() = default;
536+
537+
iterator(Game game, GameNode start_node) : m_owner(game), m_current_node(start_node)
538+
{
539+
if (!start_node) {
540+
return;
541+
}
542+
if (start_node->GetGame() != m_owner) {
543+
throw MismatchException();
544+
}
536545
}
537-
if (start_node != m_owner->GetRoot()) {
538-
throw std::invalid_argument("Node iteration can only be initiated from the game's root");
539-
}
540-
}
541546

542-
GameNode operator*() const
543-
{
544-
if (!m_current_node) {
545-
throw std::runtime_error("Cannot dereference an end iterator");
547+
value_type operator*() const
548+
{
549+
if (!m_current_node) {
550+
throw std::runtime_error("Cannot dereference an end iterator");
551+
}
552+
return m_current_node;
546553
}
547-
return m_current_node;
548-
}
549554

550-
Nodes &operator++()
551-
{
552-
if (!m_current_node) {
553-
throw std::out_of_range("Cannot increment an end iterator");
555+
iterator &operator++()
556+
{
557+
if (!m_current_node) {
558+
throw std::out_of_range("Cannot increment an end iterator");
559+
}
560+
561+
auto children = m_current_node->GetChildren();
562+
if (children.size() > 0) {
563+
m_stack.emplace(children.begin());
564+
}
565+
566+
while (!m_stack.empty()) {
567+
try {
568+
*m_stack.top();
569+
break;
570+
}
571+
catch (std::out_of_range) {
572+
m_stack.pop();
573+
}
574+
}
575+
576+
if (m_stack.empty()) {
577+
m_current_node = nullptr;
578+
}
579+
else {
580+
auto &top_it = m_stack.top();
581+
m_current_node = *top_it;
582+
++top_it;
583+
}
584+
return *this;
554585
}
555586

556-
auto children = m_current_node->GetChildren();
557-
if (children.size() > 0) {
558-
m_stack.emplace(children.begin(), children.end());
587+
bool operator==(const iterator &other) const
588+
{
589+
return m_owner == other.m_owner && m_current_node == other.m_current_node;
559590
}
591+
bool operator!=(const iterator &other) const { return !(*this == other); }
592+
};
560593

561-
while (!m_stack.empty() && m_stack.top().first == m_stack.top().second) {
562-
m_stack.pop();
563-
}
564-
565-
if (m_stack.empty()) {
566-
m_current_node = nullptr;
567-
}
568-
else {
569-
auto &top_pair = m_stack.top();
570-
m_current_node = *(top_pair.first);
571-
++(top_pair.first);
572-
}
573-
return *this;
574-
}
594+
/// Default constructor to support declaration by other modules (e.g. Cython)
595+
Nodes() = default;
575596

576-
bool operator==(const Nodes &other) const
577-
{
578-
return m_owner == other.m_owner && m_current_node == other.m_current_node;
579-
}
597+
/// Constructor for the Nodes range.
598+
explicit Nodes(Game p_owner) : m_owner(p_owner) {}
580599

581-
bool operator!=(const Nodes &other) const { return !(*this == other); }
600+
/// Returns an iterator to the first node (the root).
601+
iterator begin() const { return m_owner ? iterator(m_owner, m_owner->GetRoot()) : iterator(); }
582602

583-
friend class GameRep;
603+
/// Returns an iterator to the past-the-end position.
604+
iterator end() const { return m_owner ? iterator(m_owner) : iterator(); }
584605
};
585606

586607
/// @name Lifecycle
@@ -795,6 +816,8 @@ class GameRep : public BaseGameRep {
795816
//@{
796817
/// Returns the root node of the game
797818
virtual GameNode GetRoot() const = 0;
819+
/// Returns the nodes of the game in the DFT order
820+
Nodes GetNodes() const { return Nodes(this); }
798821
/// Returns the number of nodes in the game
799822
virtual size_t NumNodes() const = 0;
800823
/// Returns the number of non-terminal nodes in the game
@@ -807,10 +830,6 @@ class GameRep : public BaseGameRep {
807830
virtual Game SetChanceProbs(const GameInfoset &, const Array<Number> &) = 0;
808831
//@}
809832

810-
/// Node iterators
811-
Nodes begin() const { return {this, GetRoot()}; }
812-
Nodes end() const { return {this}; }
813-
814833
/// Build any computed values anew
815834
virtual void BuildComputedValues() const {}
816835
};

src/pygambit/gambit.pxd

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,13 @@ cdef extern from "games/game.h":
242242
iterator end() except +
243243

244244
cppclass Nodes:
245-
bint operator ==(Nodes)
246-
bint operator !=(Nodes)
247-
c_GameNode operator *()
248-
Nodes operator++()
245+
cppclass iterator:
246+
bint operator ==(iterator)
247+
bint operator !=(iterator)
248+
c_GameNode operator *()
249+
iterator operator++()
250+
iterator begin() except +
251+
iterator end() except +
249252

250253
int IsTree() except +
251254

@@ -270,6 +273,7 @@ cdef extern from "games/game.h":
270273
int NumNodes() except +
271274
int NumNonterminalNodes() except +
272275
c_GameNode GetRoot() except +
276+
Nodes GetNodes() except +
273277

274278
c_GameStrategy GetStrategy(int) except +IndexError
275279
c_GameStrategy NewStrategy(c_GamePlayer, string) except +
@@ -308,8 +312,6 @@ cdef extern from "games/game.h":
308312
void DeleteAction(c_GameAction) except +ValueError
309313
void SetOutcome(c_GameNode, c_GameOutcome) except +
310314
c_Game SetChanceProbs(c_GameInfoset, Array[c_Number]) except +
311-
Nodes begin() except +
312-
Nodes end() except +
313315

314316
c_PureStrategyProfile NewPureStrategyProfile() # except + doesn't compile
315317
c_MixedStrategyProfile[T] NewMixedStrategyProfile[T](T) # except + doesn't compile

src/pygambit/game.pxi

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,12 @@ class GameNodes:
186186
return self.game.deref().NumNodes()
187187

188188
def __iter__(self) -> typing.Iterator[Node]:
189-
"""
190-
A generator that efficiently iterates over the game nodes using
191-
the underlying C++ iterator, without using cdef for local variables.
192-
"""
189+
"""Iterate over the game nodes in the depth-first traversal order."""
193190
if not self.game.deref().IsTree():
194191
return
195192

196-
it = self.game.deref().begin()
197-
end_it = self.game.deref().end()
198-
199-
while it != end_it:
200-
yield Node.wrap(dereference(it))
201-
202-
preincrement(it)
193+
for node in self.game.deref().GetNodes():
194+
yield Node.wrap(node)
203195

204196

205197
@cython.cclass

0 commit comments

Comments
 (0)