Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/core/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ template <class Key, class T> bool contains(const std::map<Key, T> &map, const K
return map.find(key) != map.end();
}

/// @brief A container adaptor which skips over a given value when iterating
template <typename Container, typename T> class exclude_value {
/// @brief A container adaptor which returns only the elements matching the predicate
/// This is intended to look forward to C++20-style ranges
template <typename Container, typename Pred> class filter_if {
public:
using Iter = decltype(std::begin(std::declval<Container &>()));

Expand All @@ -78,10 +79,10 @@ template <typename Container, typename T> class exclude_value {
using reference = typename std::iterator_traits<Iter>::reference;
using pointer = typename std::iterator_traits<Iter>::pointer;

iterator(Iter current, Iter end, const T &value)
: m_current(current), m_end(end), m_value(value)
iterator(Iter current, Iter end, Pred pred)
: m_current(current), m_end(end), m_pred(std::move(pred))
{
skip_if_value();
advance_next_valid();
}

value_type operator*() const { return *m_current; }
Expand All @@ -90,7 +91,7 @@ template <typename Container, typename T> class exclude_value {
iterator &operator++()
{
++m_current;
skip_if_value();
advance_next_valid();
return *this;
}

Expand All @@ -106,25 +107,22 @@ template <typename Container, typename T> class exclude_value {
return a.m_current == b.m_current;
}

friend bool operator!=(const iterator &a, const iterator &b)
{
return a.m_current != b.m_current;
}
friend bool operator!=(const iterator &a, const iterator &b) { return !(a == b); }

private:
Iter m_current, m_end;
T m_value;
Pred m_pred;

void skip_if_value()
void advance_next_valid()
{
while (m_current != m_end && *m_current == m_value) {
while (m_current != m_end && !m_pred(*m_current)) {
++m_current;
}
}
};

exclude_value(const Container &c, const T &value)
: m_begin(c.begin(), c.end(), value), m_end(c.end(), c.end(), value)
filter_if(const Container &c, Pred pred)
: m_begin(c.begin(), c.end(), pred), m_end(c.end(), c.end(), pred)
{
}

Expand Down
9 changes: 5 additions & 4 deletions src/games/game.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,11 @@ template <class T> T MixedStrategyProfile<T>::GetRegret(const GameStrategy &p_st
CheckVersion();
ComputePayoffs();
auto player = p_strategy->GetPlayer();
T best_other_payoff = maximize_function(exclude_value(player->GetStrategies(), p_strategy),
[this, &player](const auto &strategy) -> T {
return m_strategyValues.at(player).at(strategy);
});
T best_other_payoff = maximize_function(
filter_if(player->GetStrategies(), [&](const auto &s) { return s != p_strategy; }),
[this, &player](const auto &strategy) -> T {
return m_strategyValues.at(player).at(strategy);
});
return std::max(best_other_payoff - m_strategyValues.at(player).at(p_strategy),
static_cast<T>(0));
}
Expand Down
12 changes: 12 additions & 0 deletions src/games/game.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,9 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
public:
using iterator_category = std::input_iterator_tag;
using value_type = GameNode;
using difference_type = std::ptrdiff_t;
using reference = GameNode;
using pointer = GameNode;

iterator() = default;
iterator(const iterator &) = default;
Expand Down Expand Up @@ -1015,6 +1018,15 @@ class GameRep : public std::enable_shared_from_this<GameRep> {
{
return {std::const_pointer_cast<GameRep>(shared_from_this()), p_traversal};
}
auto GetTerminalNodes() const
{
return filter_if(GetNodes(), [](const auto &node) -> bool { return node->IsTerminal(); });
}
auto GetNonterminalNodes(TraversalOrder p_traversal = TraversalOrder::Preorder) const
{
return filter_if(GetNodes(p_traversal),
[](const auto &node) -> bool { return !node->IsTerminal(); });
}
/// Returns the number of nodes in the game
virtual size_t NumNodes() const = 0;
/// Returns the number of non-terminal nodes in the game
Expand Down
44 changes: 0 additions & 44 deletions src/pygambit/game.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -194,41 +194,6 @@ class GameNodes:
yield Node.wrap(node)


@cython.cclass
class GameNonterminalNodes:
"""Represents the set of nodes in a game."""
game = cython.declare(c_Game)

def __init__(self, *args, **kwargs) -> None:
raise ValueError("Cannot create GameNonterminalNodes outside a Game.")

@staticmethod
@cython.cfunc
def wrap(game: c_Game) -> GameNonterminalNodes:
obj: GameNonterminalNodes = GameNonterminalNodes.__new__(GameNonterminalNodes)
obj.game = game
return obj

def __repr__(self) -> str:
return f"GameNonterminalNodes(game={Game.wrap(self.game)})"

def __len__(self) -> int:
"""The number of non-terminal nodes in the game."""
if not self.game.deref().IsTree():
return 0
return self.game.deref().NumNonterminalNodes()

def __iter__(self) -> typing.Iterator[Node]:
def dfs(node):
if not node.is_terminal:
yield node
for child in node.children:
yield from dfs(child)
if not self.game.deref().IsTree():
return
yield from dfs(Node.wrap(self.game.deref().GetRoot()))


@cython.cclass
class GameOutcomes:
"""Represents the set of outcomes in a game."""
Expand Down Expand Up @@ -745,15 +710,6 @@ class Game:
"""
return GameNodes.wrap(self.game)

@property
def _nonterminal_nodes(self) -> GameNonterminalNodes:
"""The set of non-terminal nodes in the game.

Iteration over this property yields the non-terminal nodes in the order of depth-first
search.
"""
return GameNonterminalNodes.wrap(self.game)

@property
def contingencies(self) -> pygambit.gameiter.Contingencies:
"""An iterator over the contingencies in the game."""
Expand Down
176 changes: 0 additions & 176 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,182 +762,6 @@ def test_len_after_copy_tree():
assert len(game.nodes) == initial_number_of_nodes + number_of_src_ancestors - 1


def test_nonterminal_len_matches_expected_count():
"""Verify `len(game._nonterminal_nodes)` matches expected count
"""
game = games.read_from_file("e01.efg")
expected_nonterminal_node_count = 4

direct_nonterminal_len = len(game._nonterminal_nodes)
assert direct_nonterminal_len == expected_nonterminal_node_count


def test_nonterminal_len_after_delete_tree():
"""Verify `len(game._nonterminal_nodes)` is correct after `delete_tree`.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)

root_of_the_deleted_subtree = list_nodes[1]
number_of_deleted_nonterminal_nodes = _count_subtree_nodes(root_of_the_deleted_subtree, False)

game.delete_tree(root_of_the_deleted_subtree)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes \
- number_of_deleted_nonterminal_nodes


def test_nonterminal_len_after_delete_parent_of_nonterminal_node():
"""Verify `len(game._nonterminal_nodes)` is correct after `delete_parent`.
"""
game = games.read_from_file("e02.efg")
list_nodes = list(game.nodes)
node_parent_to_delete = list_nodes[4] # path=[1, 1]

initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
diff = _count_subtree_nodes(node_parent_to_delete.parent, False) \
- _count_subtree_nodes(node_parent_to_delete, False)

game.delete_parent(node_parent_to_delete)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes - diff


def test_nonterminal_len_after_delete_parent_of_terminal_node():
"""Verify `len(game._nonterminal_nodes)` is correct after `delete_parent`.
"""
game = games.read_from_file("e02.efg")
list_nodes = list(game.nodes)
node_parent_to_delete = list_nodes[5] # path=[0, 1, 1]

initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
diff = _count_subtree_nodes(node_parent_to_delete.parent, False) \
- _count_subtree_nodes(node_parent_to_delete, False)

game.delete_parent(node_parent_to_delete)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes - diff


def test_nonterminal_len_after_append_move():
"""Verify `len(game._nonterminal_nodes)` is correct after `append_move`.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)

terminal_node = list_nodes[5] # path=[1, 1, 0]
player = game.players[0]
actions_to_add = ["T", "M", "B"]

game.append_move(terminal_node, player, actions_to_add)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes \
+ _count_subtree_nodes(terminal_node, False)


def test_nonterminal_len_after_append_infoset():
"""Verify `len(game._nonterminal_nodes)` is correct after `append_infoset`.
"""
game = games.read_from_file("e02.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)

member_node = list_nodes[2] # path=[1]
infoset_to_modify = member_node.infoset
terminal_node_to_add = list_nodes[6] # path=[1, 1, 1]

game.append_infoset(terminal_node_to_add, infoset_to_modify)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes \
+ _count_subtree_nodes(terminal_node_to_add, False)


def test_nonterminal_len_after_add_action():
"""Verify `len(game._nonterminal_nodes)` does not change after `add_action` to an infoset.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)

infoset_to_modify = game.infosets[1]

game.add_action(infoset_to_modify)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes


def test_nonterminal_len_after_delete_action():
"""Verify `len(game._nonterminal_nodes)` is correct after `delete_action`.
"""
game = games.read_from_file("e02.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)

action_to_delete = game.infosets[0].actions[1]

# Calculate the total number of nodes within all subtrees
# that begin immediately after taking the specified action.
nonterminal_nodes_to_delete = 0
action_nodes = _get_members(action_to_delete)

for subtree_root in action_nodes:
nonterminal_nodes_to_delete += _count_subtree_nodes(subtree_root, False)

game.delete_action(action_to_delete)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes \
- nonterminal_nodes_to_delete


def test_nonterminal_len_after_insert_move():
"""Verify `len(game._nonterminal_nodes)` correctly increaces by 1 after `insert_move`.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)

node_to_insert_above = list_nodes[3]

player = game.players[1]
num_actions_to_add = 3

game.insert_move(node_to_insert_above, player, num_actions_to_add)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes + 1


def test_nonterminal_len_after_insert_infoset():
"""Verify `len(game._nonterminal_nodes)` correctly increaces by 1 after `insert_infoset`.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nonterminal_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)

member_node = list_nodes[6] # path=[1]
infoset_to_modify = member_node.infoset
node_to_insert_above = list_nodes[7] # path=[0, 1]

game.insert_infoset(node_to_insert_above, infoset_to_modify)

assert len(game._nonterminal_nodes) == initial_number_of_nonterminal_nodes + 1


def test_nonterminal_len_after_copy_tree():
"""Verify `len(game._nonterminal_nodes)` is correct after `copy_tree`.
"""
game = games.read_from_file("e01.efg")
initial_number_of_nodes = len(game._nonterminal_nodes)
list_nodes = list(game.nodes)
src_node = list_nodes[3] # path=[1, 0]
dest_node = list_nodes[2] # path=[0, 0]
number_of_nonterminal_src_ancestors = _count_subtree_nodes(src_node, False)

game.copy_tree(src_node, dest_node)

assert len(game._nonterminal_nodes) == initial_number_of_nodes \
+ number_of_nonterminal_src_ancestors


def test_node_plays():
"""Verify `node.plays` returns plays reachable from a given node.
"""
Expand Down
Loading