Skip to content

Commit 265ca94

Browse files
committed
overload GetPlays for Infoset and Action
1 parent b767236 commit 265ca94

8 files changed

Lines changed: 69 additions & 7 deletions

File tree

src/games/game.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,10 @@ class GameRep : public BaseGameRep {
476476

477477
/// Returns the set of terminal nodes which are descendants of node
478478
virtual std::vector<GameNode> GetPlays(GameNode node) const = 0;
479+
/// Returns the set of terminal nodes which are descendants of members of an infoset
480+
virtual std::vector<GameNode> GetPlays(GameInfoset infoset) const = 0;
481+
/// Returns the set of terminal nodes which are descendants of members of an action
482+
virtual std::vector<GameNode> GetPlays(GameAction action) const = 0;
479483

480484
/// Returns true if the game is perfect recall. If not,
481485
/// a pair of violating information sets is returned in the parameters.

src/games/gameagg.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ class GameAGGRep : public GameRep {
105105
size_t NumNonterminalNodes() const override { throw UndefinedException(); }
106106
/// Returns the set of terminal nodes which are descendants of node
107107
std::vector<GameNode> GetPlays(GameNode node) const override { throw UndefinedException(); }
108+
/// Returns the set of terminal nodes which are descendants of members of an infoset
109+
std::vector<GameNode> GetPlays(GameInfoset infoset) const override
110+
{
111+
throw UndefinedException();
112+
}
113+
/// Returns the set of terminal nodes which are descendants of members of an action
114+
std::vector<GameNode> GetPlays(GameAction action) const override { throw UndefinedException(); }
108115
//@}
109116

110117
/// @name General data access

src/games/gamebagg.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ class GameBAGGRep : public GameRep {
113113
size_t NumNonterminalNodes() const override { throw UndefinedException(); }
114114
/// Returns the set of terminal nodes which are descendants of node
115115
std::vector<GameNode> GetPlays(GameNode node) const override { throw UndefinedException(); }
116+
/// Returns the set of terminal nodes which are descendants of members of an infoset
117+
std::vector<GameNode> GetPlays(GameInfoset infoset) const override
118+
{
119+
throw UndefinedException();
120+
}
121+
/// Returns the set of terminal nodes which are descendants of members of an action
122+
std::vector<GameNode> GetPlays(GameAction action) const override { throw UndefinedException(); }
116123
//@}
117124

118125
/// @name General data access

src/games/gametable.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ class GameTableRep : public GameExplicitRep {
9494
size_t NumNonterminalNodes() const override { throw UndefinedException(); }
9595
/// Returns the set of terminal nodes which are descendants of node
9696
std::vector<GameNode> GetPlays(GameNode node) const override { throw UndefinedException(); }
97+
/// Returns the set of terminal nodes which are descendants of members of an infoset
98+
std::vector<GameNode> GetPlays(GameInfoset infoset) const override
99+
{
100+
throw UndefinedException();
101+
}
102+
/// Returns the set of terminal nodes which are descendants of members of an action
103+
std::vector<GameNode> GetPlays(GameAction action) const override { throw UndefinedException(); }
97104
//@}
98105

99106
/// @name Outcomes

src/games/gametree.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,31 @@ std::vector<GameNode> GameTreeRep::GetPlays(GameNode node) const
10801080
return consistent_plays_copy;
10811081
}
10821082

1083+
std::vector<GameNode> GameTreeRep::GetPlays(GameInfoset infoset) const
1084+
{
1085+
std::vector<GameNode> plays;
1086+
auto members = infoset->GetMembers();
1087+
1088+
for (const GameNode &node : members) {
1089+
std::vector<GameNode> member_plays = GetPlays(node);
1090+
plays.insert(plays.end(), member_plays.begin(), member_plays.end());
1091+
}
1092+
return plays;
1093+
}
1094+
1095+
std::vector<GameNode> GameTreeRep::GetPlays(GameAction action) const
1096+
{
1097+
std::vector<GameNode> plays;
1098+
auto infoset = action->GetInfoset();
1099+
auto members = infoset->GetMembers();
1100+
1101+
for (const GameNode &node : members) {
1102+
std::vector<GameNode> child_plays = GetPlays(node->GetChild(action));
1103+
plays.insert(plays.end(), child_plays.begin(), child_plays.end());
1104+
}
1105+
return plays;
1106+
}
1107+
10831108
void GameTreeRep::DeleteOutcome(const GameOutcome &p_outcome)
10841109
{
10851110
IncrementVersion();

src/games/gametree.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ class GameTreeRep : public GameExplicitRep {
146146
void SetOutcome(GameNode, const GameOutcome &p_outcome) override;
147147

148148
std::vector<GameNode> GetPlays(GameNode node) const override;
149+
std::vector<GameNode> GetPlays(GameInfoset infoset) const override;
150+
std::vector<GameNode> GetPlays(GameAction action) const override;
149151

150152
Game CopySubgame(GameNode) const override;
151153
//@}

src/pygambit/gambit.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ cdef extern from "games/game.h":
209209
c_Rational GetMaxPayoff() except +
210210
c_Rational GetMaxPayoff(c_GamePlayer) except +
211211
stdvector[c_GameNode] GetPlays(c_GameNode) except +
212+
stdvector[c_GameNode] GetPlays(c_GameInfoset) except +
213+
stdvector[c_GameNode] GetPlays(c_GameAction) except +
212214
bool IsPerfectRecall() except +
213215

214216
c_GameInfoset AppendMove(c_GameNode, c_GamePlayer, int) except +ValueError

src/pygambit/game.pxi

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import io
2323
import itertools
2424
import pathlib
25-
import warnings
2625

2726
import numpy as np
2827
import scipy.stats
@@ -2014,11 +2013,20 @@ class Game:
20142013
raise UndefinedOperationError("Cannot delete the only strategy for a player")
20152014
self.game.deref().DeleteStrategy(resolved_strategy.strategy)
20162015

2017-
def get_plays(self, node: typing.Union[Node, str]) -> typing.List[Node]:
2018-
resolved_node = cython.cast(Node, self._resolve_node(node, "get_plays", "node_obj"))
2016+
def get_plays(self, obj: typing.Union[Node, Infoset, Action]) -> typing.List[Node]:
20192017

2020-
plays = []
2021-
for item in self.game.deref().GetPlays(resolved_node.node):
2022-
plays.append(Node.wrap(item))
2018+
c_plays = cython.declare(stdvector[c_GameNode])
20232019

2024-
return plays
2020+
if isinstance(obj, Node):
2021+
obj = cython.cast(Node, obj)
2022+
c_plays = self.game.deref().GetPlays(obj.node)
2023+
elif isinstance(obj, Infoset):
2024+
obj = cython.cast(Infoset, obj)
2025+
c_plays = self.game.deref().GetPlays(obj.infoset)
2026+
elif isinstance(obj, Action):
2027+
obj = cython.cast(Action, obj)
2028+
c_plays = self.game.deref().GetPlays(obj.action)
2029+
else:
2030+
raise TypeError("The object needs to be either Node, Infoset, or Action")
2031+
2032+
return [Node.wrap(item) for item in c_plays]

0 commit comments

Comments
 (0)