Skip to content

Commit b20240a

Browse files
switch kwargs for args in games()
1 parent b6e371a commit b20240a

File tree

2 files changed

+42
-44
lines changed

2 files changed

+42
-44
lines changed

catalog/__init__.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,20 @@ def load(slug: str) -> gbt.Game:
3737
raise FileNotFoundError(f"No catalog entry called {slug}")
3838

3939

40-
def games(**kwargs) -> pd.DataFrame:
40+
def games(
41+
n_actions: int | None = None,
42+
n_contingencies: int | None = None,
43+
n_infosets: int | None = None,
44+
is_const_sum: bool | None = None,
45+
is_perfect_recall: bool | None = None,
46+
is_tree: bool | None = None,
47+
min_payoff: float | None = None,
48+
max_payoff: float | None = None,
49+
n_nodes: int | None = None,
50+
n_outcomes: int | None = None,
51+
n_players: int | None = None,
52+
n_strategies: int | None = None,
53+
) -> pd.DataFrame:
4154
"""
4255
List games available in the package catalog.
4356
@@ -81,58 +94,39 @@ def games(**kwargs) -> pd.DataFrame:
8194
"""
8295
records: list[dict[str, Any]] = []
8396

84-
# Raise an error if any invalid filter keys are provided
85-
valid_filter_keys = {
86-
"n_actions",
87-
"n_contingencies",
88-
"n_infosets",
89-
"is_const_sum",
90-
"is_perfect_recall",
91-
"is_tree",
92-
"min_payoff",
93-
"max_payoff",
94-
"n_nodes",
95-
"n_outcomes",
96-
"n_players",
97-
"n_strategies",
98-
}
99-
for key in kwargs:
100-
if key not in valid_filter_keys:
101-
raise ValueError(f"Invalid kwarg: {key}. Valid kwargs are: {valid_filter_keys}")
102-
10397
def check_filters(game: gbt.Game) -> bool:
104-
if "n_actions" in kwargs:
98+
if n_actions is not None:
10599
if not game.is_tree:
106100
return False
107-
if len(game.actions) != kwargs["n_actions"]:
101+
if len(game.actions) != n_actions:
108102
return False
109-
if "n_contingencies" in kwargs and len(game.contingencies) != kwargs["n_contingencies"]:
103+
if n_contingencies is not None and len(game.contingencies) != n_contingencies:
110104
return False
111-
if "n_infosets" in kwargs:
105+
if n_infosets is not None:
112106
if not game.is_tree:
113107
return False
114-
if len(game.infosets) != kwargs["n_infosets"]:
108+
if len(game.infosets) != n_infosets:
115109
return False
116-
if "is_const_sum" in kwargs and game.is_const_sum != kwargs["is_const_sum"]:
110+
if is_const_sum is not None and game.is_const_sum != is_const_sum:
117111
return False
118-
if "is_perfect_recall" in kwargs and game.is_perfect_recall != kwargs["is_perfect_recall"]:
112+
if is_perfect_recall is not None and game.is_perfect_recall != is_perfect_recall:
119113
return False
120-
if "is_tree" in kwargs and game.is_tree != kwargs["is_tree"]:
114+
if is_tree is not None and game.is_tree != is_tree:
121115
return False
122-
if "min_payoff" in kwargs and game.min_payoff < kwargs["min_payoff"]:
116+
if min_payoff is not None and game.min_payoff < min_payoff:
123117
return False
124-
if "max_payoff" in kwargs and game.max_payoff > kwargs["max_payoff"]:
118+
if max_payoff is not None and game.max_payoff > max_payoff:
125119
return False
126-
if "n_nodes" in kwargs:
120+
if n_nodes is not None:
127121
if not game.is_tree:
128122
return False
129-
if len(game.nodes) != kwargs["n_nodes"]:
123+
if len(game.nodes) != n_nodes:
130124
return False
131-
if "n_outcomes" in kwargs and len(game.outcomes) != kwargs["n_outcomes"]:
125+
if n_outcomes is not None and len(game.outcomes) != n_outcomes:
132126
return False
133-
if "n_players" in kwargs and len(game.players) != kwargs["n_players"]:
127+
if n_players is not None and len(game.players) != n_players:
134128
return False
135-
return not ("n_strategies" in kwargs and len(game.strategies) != kwargs["n_strategies"])
129+
return not (n_strategies is not None and len(game.strategies) != n_strategies)
136130

137131
# Add all the games stored as EFG/NFG files
138132
for resource_path in sorted(_CATALOG_RESOURCE.rglob("*")):
@@ -147,10 +141,12 @@ def check_filters(game: gbt.Game) -> bool:
147141
with as_file(resource_path) as path:
148142
game = reader(str(path))
149143
if check_filters(game):
150-
records.append({
151-
"Game": slug,
152-
"Title": game.title,
153-
})
144+
records.append(
145+
{
146+
"Game": slug,
147+
"Title": game.title,
148+
}
149+
)
154150

155151
# Add all the games from families
156152
for slug, game in family_games().items():
@@ -160,10 +156,12 @@ def check_filters(game: gbt.Game) -> bool:
160156
f"Slug collision: {slug} is present in both file-based and family games."
161157
)
162158
if check_filters(game):
163-
records.append({
164-
"Game": slug,
165-
"Title": game.title,
166-
})
159+
records.append(
160+
{
161+
"Game": slug,
162+
"Title": game.title,
163+
}
164+
)
167165

168166
return pd.DataFrame.from_records(records, columns=["Game", "Title"])
169167

tests/test_catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,5 @@ def test_catalog_games_filter_n_strategies(all_games):
130130

131131
def test_catalog_games_filter_bad_filter():
132132
"""Test games() function raises error on invalid filter key"""
133-
with pytest.raises(ValueError):
133+
with pytest.raises(TypeError):
134134
gbt.catalog.games(invalid_filter=123)

0 commit comments

Comments
 (0)