Skip to content

Commit 9d1fa01

Browse files
committed
Improve tests, add two methods to ElementCollection
1 parent 0b370da commit 9d1fa01

2 files changed

Lines changed: 6 additions & 8 deletions

File tree

src/games/game.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ template <class P, class T> class ElementCollection {
6262

6363
public:
6464
class iterator {
65-
friend class GameRep;
6665

6766
P m_owner{nullptr};
6867
const std::vector<T *> *m_container{nullptr};
@@ -106,6 +105,7 @@ template <class P, class T> class ElementCollection {
106105
return *this;
107106
}
108107
value_type operator*() const { return m_container->at(m_index); }
108+
P GetOwner() const { return m_owner; }
109109
};
110110

111111
ElementCollection() = default;
@@ -121,6 +121,8 @@ template <class P, class T> class ElementCollection {
121121
{
122122
return m_owner == p_other.m_owner && m_container == p_other.m_container;
123123
}
124+
125+
bool empty() const { return size() == 0; }
124126
size_t size() const { return m_container->size(); }
125127
GameObjectPtr<T> front() const { return m_container->front(); }
126128
GameObjectPtr<T> back() const { return m_container->back(); }
@@ -566,7 +568,7 @@ class GameRep : public BaseGameRep {
566568

567569
while (!m_stack.empty()) {
568570
auto &top_it = m_stack.top();
569-
auto end_it = top_it.m_owner->GetChildren().end();
571+
auto end_it = top_it.GetOwner()->GetChildren().end();
570572
if (top_it != end_it) {
571573
m_current_node = *top_it;
572574
++top_it;

tests/test_node.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -803,15 +803,11 @@ def test_node_plays():
803803
],
804804
)
805805
def test_nodes_iteration_order(game_obj: gbt.Game):
806-
"""
807-
Verify that the C++ `game.nodes` iterator produces the DFS traversal.
808-
806+
"""Verify that the C++ `game.nodes` iterator produces the DFS traversal.
809807
"""
810808
def dfs(node: gbt.Node) -> typing.Iterator[gbt.Node]:
811809
yield node
812810
for child in node.children:
813811
yield from dfs(child)
814812

815-
zipped_nodes = itertools.zip_longest(game_obj.nodes, dfs(game_obj.root))
816-
817-
assert all(a == b for a, b in zipped_nodes)
813+
assert all(a == b for a, b in itertools.zip_longest(game_obj.nodes, dfs(game_obj.root)))

0 commit comments

Comments
 (0)