diff --git a/ChangeLog b/ChangeLog index 44821fff3..4554f3288 100644 --- a/ChangeLog +++ b/ChangeLog @@ -9,6 +9,11 @@ - The deprecated functions `Game.read_game`, `Game.parse_game` and `Game.write` functions have been removed as planned. (#357) +### Added +- Added `Node.plays`, `Infoset.plays`, and `Action.plays` properties. + These properties return a list of terminal `Node` objects representing (terminal) plays + consistent with the specific node, information set, or action. + This functionality is backed by new C++ `GameRep::GetPlays()` overloads. ## [16.3.1] - unreleased diff --git a/doc/pygambit.api.rst b/doc/pygambit.api.rst index 3b32c7f02..50052a873 100644 --- a/doc/pygambit.api.rst +++ b/doc/pygambit.api.rst @@ -145,6 +145,7 @@ Information about the game Node.infoset Node.player Node.is_successor_of + Node.plays .. autosummary:: @@ -157,6 +158,7 @@ Information about the game Infoset.actions Infoset.members Infoset.precedes + Infoset.plays .. autosummary:: @@ -166,6 +168,7 @@ Information about the game Action.infoset Action.precedes Action.prob + Action.plays .. autosummary:: diff --git a/src/games/game.h b/src/games/game.h index 41af10140..53663a098 100644 --- a/src/games/game.h +++ b/src/games/game.h @@ -474,6 +474,13 @@ class GameRep : public BaseGameRep { /// Returns the largest payoff to the player in any outcome of the game virtual Rational GetMaxPayoff(const GamePlayer &p_player) const = 0; + /// Returns the set of terminal nodes which are descendants of node + virtual std::vector GetPlays(GameNode node) const { throw UndefinedException(); } + /// Returns the set of terminal nodes which are descendants of members of an infoset + virtual std::vector GetPlays(GameInfoset infoset) const { throw UndefinedException(); } + /// Returns the set of terminal nodes which are descendants of members of an action + virtual std::vector GetPlays(GameAction action) const { throw UndefinedException(); } + /// Returns true if the game is perfect recall. If not, /// a pair of violating information sets is returned in the parameters. virtual bool IsPerfectRecall(GameInfoset &, GameInfoset &) const = 0; diff --git a/src/games/gametree.cc b/src/games/gametree.cc index 711195609..dbb168f86 100644 --- a/src/games/gametree.cc +++ b/src/games/gametree.cc @@ -886,6 +886,7 @@ void GameTreeRep::ClearComputedValues() const } player->m_strategies.clear(); } + const_cast(this)->m_nodePlays.clear(); m_computedValues = false; } @@ -901,6 +902,29 @@ void GameTreeRep::BuildComputedValues() const m_computedValues = true; } +void GameTreeRep::BuildConsistentPlays() +{ + m_nodePlays.clear(); + BuildConsistentPlaysRecursiveImpl(m_root); +} + +std::vector GameTreeRep::BuildConsistentPlaysRecursiveImpl(GameNodeRep *node) +{ + std::vector consistent_plays; + if (node->IsTerminal()) { + consistent_plays = std::vector{node}; + } + else { + for (GameNodeRep *child : node->GetChildren()) { + auto child_consistent_plays = BuildConsistentPlaysRecursiveImpl(child); + consistent_plays.insert(consistent_plays.end(), child_consistent_plays.begin(), + child_consistent_plays.end()); + } + } + m_nodePlays[node] = consistent_plays; + return consistent_plays; +} + //------------------------------------------------------------------------ // GameTreeRep: Writing data files //------------------------------------------------------------------------ @@ -1040,6 +1064,43 @@ Array GameTreeRep::NumInfosets() const // GameTreeRep: Outcomes //------------------------------------------------------------------------ +std::vector GameTreeRep::GetPlays(GameNode node) const +{ + const_cast(this)->BuildConsistentPlays(); + + const std::vector &consistent_plays = m_nodePlays.at(node); + std::vector consistent_plays_copy; + consistent_plays_copy.reserve(consistent_plays.size()); + + std::transform(consistent_plays.cbegin(), consistent_plays.cend(), + std::back_inserter(consistent_plays_copy), + [](GameNodeRep *rep_ptr) -> GameNode { return {rep_ptr}; }); + + return consistent_plays_copy; +} + +std::vector GameTreeRep::GetPlays(GameInfoset infoset) const +{ + std::vector plays; + + for (const GameNode &node : infoset->GetMembers()) { + std::vector member_plays = GetPlays(node); + plays.insert(plays.end(), member_plays.begin(), member_plays.end()); + } + return plays; +} + +std::vector GameTreeRep::GetPlays(GameAction action) const +{ + std::vector plays; + + for (const GameNode &node : action->GetInfoset()->GetMembers()) { + std::vector child_plays = GetPlays(node->GetChild(action)); + plays.insert(plays.end(), child_plays.begin(), child_plays.end()); + } + return plays; +} + void GameTreeRep::DeleteOutcome(const GameOutcome &p_outcome) { IncrementVersion(); diff --git a/src/games/gametree.h b/src/games/gametree.h index 76347dba2..74ffc9865 100644 --- a/src/games/gametree.h +++ b/src/games/gametree.h @@ -38,6 +38,7 @@ class GameTreeRep : public GameExplicitRep { GamePlayerRep *m_chance; std::size_t m_numNodes = 1; std::size_t m_numNonterminalNodes = 0; + std::map> m_nodePlays; /// @name Private auxiliary functions //@{ @@ -50,6 +51,7 @@ class GameTreeRep : public GameExplicitRep { //@{ void Canonicalize(); void BuildComputedValues() const override; + void BuildConsistentPlays(); void ClearComputedValues() const; /// Removes the node from the information set, invalidating if emptied @@ -143,6 +145,10 @@ class GameTreeRep : public GameExplicitRep { void DeleteAction(GameAction) override; void SetOutcome(GameNode, const GameOutcome &p_outcome) override; + std::vector GetPlays(GameNode node) const override; + std::vector GetPlays(GameInfoset infoset) const override; + std::vector GetPlays(GameAction action) const override; + Game CopySubgame(GameNode) const override; //@} @@ -153,6 +159,9 @@ class GameTreeRep : public GameExplicitRep { NewMixedStrategyProfile(double, const StrategySupportProfile &) const override; MixedStrategyProfile NewMixedStrategyProfile(const Rational &, const StrategySupportProfile &) const override; + +private: + std::vector BuildConsistentPlaysRecursiveImpl(GameNodeRep *node); }; template class TreeMixedStrategyProfileRep : public MixedStrategyProfileRep { diff --git a/src/pygambit/action.pxi b/src/pygambit/action.pxi index cf465d122..576834b54 100644 --- a/src/pygambit/action.pxi +++ b/src/pygambit/action.pxi @@ -106,3 +106,12 @@ class Action: return decimal.Decimal(py_string.decode("ascii")) else: return Rational(py_string.decode("ascii")) + + @property + def plays(self) -> typing.List[Node]: + """Returns a list of all terminal `Node` objects consistent with it. + """ + return [ + Node.wrap(n) for n in + self.action.deref().GetInfoset().deref().GetGame().deref().GetPlays(self.action) + ] diff --git a/src/pygambit/gambit.pxd b/src/pygambit/gambit.pxd index 8efd884fe..d2449cdef 100644 --- a/src/pygambit/gambit.pxd +++ b/src/pygambit/gambit.pxd @@ -208,6 +208,9 @@ cdef extern from "games/game.h": c_Rational GetMinPayoff(c_GamePlayer) except + c_Rational GetMaxPayoff() except + c_Rational GetMaxPayoff(c_GamePlayer) except + + stdvector[c_GameNode] GetPlays(c_GameNode) except + + stdvector[c_GameNode] GetPlays(c_GameInfoset) except + + stdvector[c_GameNode] GetPlays(c_GameAction) except + bool IsPerfectRecall() except + c_GameInfoset AppendMove(c_GameNode, c_GamePlayer, int) except +ValueError diff --git a/src/pygambit/game.pxi b/src/pygambit/game.pxi index 734b28456..42509c2bc 100644 --- a/src/pygambit/game.pxi +++ b/src/pygambit/game.pxi @@ -22,7 +22,6 @@ import io import itertools import pathlib -import warnings import numpy as np import scipy.stats diff --git a/src/pygambit/infoset.pxi b/src/pygambit/infoset.pxi index aef0afb9e..552b41d32 100644 --- a/src/pygambit/infoset.pxi +++ b/src/pygambit/infoset.pxi @@ -175,3 +175,11 @@ class Infoset: def player(self) -> Player: """The player who has the move at this information set.""" return Player.wrap(self.infoset.deref().GetPlayer()) + + @property + def plays(self) -> typing.List[Node]: + """Returns a list of all terminal `Node` objects consistent with it. + """ + return [ + Node.wrap(n) for n in self.infoset.deref().GetGame().deref().GetPlays(self.infoset) + ] diff --git a/src/pygambit/node.pxi b/src/pygambit/node.pxi index 78741252e..f8ec2bd29 100644 --- a/src/pygambit/node.pxi +++ b/src/pygambit/node.pxi @@ -212,3 +212,9 @@ class Node: if self.node.deref().GetOutcome() == cython.cast(c_GameOutcome, NULL): return None return Outcome.wrap(self.node.deref().GetOutcome()) + + @property + def plays(self) -> typing.List[Node]: + """Returns a list of all terminal `Node` objects consistent with it. + """ + return [Node.wrap(n) for n in self.node.deref().GetGame().deref().GetPlays(self.node)] diff --git a/tests/test_actions.py b/tests/test_actions.py index 10d55ca35..9b46e7d3d 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -127,3 +127,19 @@ def test_action_delete_chance(game: gbt.Game): assert p2 == p1 / (1-old_probs[0]) with pytest.raises(gbt.UndefinedOperationError): game.delete_action(chance_iset.actions[0]) + + +def test_action_plays(): + """Verify `action.plays` returns plays reachable from a given action. + """ + game = games.read_from_file("e01.efg") + list_nodes = list(game.nodes) + list_infosets = list(game.infosets) + + test_action = list_infosets[2].actions[0] # members' paths=[0, 1, 0], [0, 1] + + expected_set_of_plays = { + list_nodes[4], list_nodes[7] + } # paths=[0, 1, 0], [0, 1] + + assert set(test_action.plays) == expected_set_of_plays diff --git a/tests/test_infosets.py b/tests/test_infosets.py index 9d9c96495..a6f5b6d6a 100644 --- a/tests/test_infosets.py +++ b/tests/test_infosets.py @@ -53,3 +53,19 @@ def test_infoset_add_action_error(): game = games.read_from_file("basic_extensive_game.efg") with pytest.raises(gbt.MismatchError): game.add_action(game.players[0].infosets[0], game.players[1].infosets[0].actions[0]) + + +def test_infoset_plays(): + """Verify `infoset.plays` returns plays reachable from a given infoset. + """ + game = games.read_from_file("e01.efg") + list_nodes = list(game.nodes) + list_infosets = list(game.infosets) + + test_infoset = list_infosets[2] # members' paths=[1, 0], [1] + + expected_set_of_plays = { + list_nodes[4], list_nodes[5], list_nodes[7], list_nodes[8] + } # paths=[0, 1, 0], [1, 1, 0], [0, 1], [1, 1] + + assert set(test_infoset.plays) == expected_set_of_plays diff --git a/tests/test_node.py b/tests/test_node.py index b4c25595e..c61b56745 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -772,3 +772,18 @@ def test_nonterminal_len_after_copy_tree(): assert len(game._nonterminal_nodes) == initial_number_of_nodes \ + number_of_nonterminal_src_ancestors + + +def test_node_plays(): + """Verify `node.plays` returns plays reachable from a given node. + """ + game = games.read_from_file("e02.efg") + list_nodes = list(game.nodes) + + test_node = list_nodes[2] # path=[1] + + expected_set_of_plays = { + list_nodes[3], list_nodes[5], list_nodes[6] + } # paths=[0, 1], [0, 1, 1], [1, 1, 1] + + assert set(test_node.plays) == expected_set_of_plays