Skip to content

Commit eff5acc

Browse files
committed
test_qre_solver along with LOGIT_BRANCH_CASES and LOGIT_LAMBDA_CASES
1 parent aa761a5 commit eff5acc

2 files changed

Lines changed: 80 additions & 21 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ markers = [
8787
"nash_logit_behavior: tests of logit_solve in behavior strategies",
8888
"nash_gnm_strategy: tests of gnm_solve in mixed strategies",
8989
"nash_ipa_strategy: tests of lpa_solve in mixed strategies",
90+
"qre_logit: tests of logit_solve and related methods",
91+
"qre_logit_lambda: tests of logit_solve_lambda",
92+
"qre_logit_branch: tests of logit_solve_branch",
9093
"nash: all tests of Nash equilibrium solvers",
9194
"slow: all time-consuming tests",
9295
]

tests/test_nash.py

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,23 @@ class EquilibriumTestCase:
3535
"""Summarising the data relevant for a test fixture of a call to an equilibrium solver."""
3636

3737
factory: typing.Callable[[], gbt.Game]
38-
solver: typing.Callable[[gbt.Game], gbt.nash.NashComputationResult]
38+
solver: typing.Callable[[gbt.Game], gbt.qre.LogitQREMixedStrategyFitResult]
3939
expected: list
4040
regret_tol: float | gbt.Rational = Q(0)
4141
prob_tol: float | gbt.Rational = Q(0)
4242

4343

44+
@dataclasses.dataclass
45+
class QREquilibriumTestCase:
46+
"""Summarising the data relevant for a test fixture of a call to an QRE solver."""
47+
48+
factory: typing.Callable[[], gbt.Game]
49+
solver: typing.Callable[[gbt.Game], gbt.nash.NashComputationResult]
50+
expected: list
51+
prob_tol: float
52+
lam_tol: float
53+
54+
4455
##################################################################################################
4556
# NASH SOLVER IN PURE/MIXED STRATEGIES (as opposed to pure/mixed behaviors)
4657
##################################################################################################
@@ -2410,27 +2421,72 @@ def test_simpdiv_strategy():
24102421
##################################################################################################
24112422

24122423

2413-
# Needs a new solver tester
2414-
def test_logit_solve_lambda():
2415-
game = games.read_from_file("const_sum_game.nfg")
2416-
assert (
2417-
len(gbt.qre.logit_solve_lambda(game=game, lam=[1, 2, 3], first_step=0.2, max_accel=1)) > 0
2418-
)
2419-
# [LogitQREMixedStrategyProfile(lam=1.000000,profile=[[0.6429793593274791, 0.3570206406725209],
2420-
# [0.588319024552166, 0.41168097544783405]]),
2421-
# LogitQREMixedStrategyProfile(lam=2.000000,profile=[[0.7726766071376159, 0.2273233928623842],
2422-
# [0.6117434791999494, 0.38825652080005063]]),
2423-
# LogitQREMixedStrategyProfile(lam=3.000000,profile=[[0.859536709259968, 0.14046329074003203],
2424-
# [0.6038157860344706, 0.39618421396552944]])]
2425-
2426-
2427-
def test_logit_solve_branch():
2428-
game = games.read_from_file("const_sum_game.nfg")
2429-
assert (
2430-
len(gbt.qre.logit_solve_branch(game=game, maxregret=0.2, first_step=0.2, max_accel=1)) > 0
2431-
)
2424+
LOGIT_BRANCH_CASES = [
2425+
pytest.param(
2426+
QREquilibriumTestCase(
2427+
factory=functools.partial(games.read_from_file, "const_sum_game.nfg"),
2428+
solver=functools.partial(
2429+
gbt.qre.logit_solve_branch, maxregret=0.2, first_step=0.2, max_accel=1
2430+
),
2431+
expected=[{"idx": 0, "lam": 0, "profile": [d(0.5, 0.5), d(0.5, 0.5)]}],
2432+
prob_tol=TOL_LARGE,
2433+
lam_tol=TOL_LARGE,
2434+
),
2435+
marks=pytest.mark.qre_logit,
2436+
id="test_logit_branch_1",
2437+
),
2438+
]
2439+
2440+
2441+
LOGIT_LAMBDA_CASES = [
2442+
pytest.param(
2443+
QREquilibriumTestCase(
2444+
factory=functools.partial(games.read_from_file, "const_sum_game.nfg"),
2445+
solver=functools.partial(
2446+
gbt.qre.logit_solve_lambda, lam=[1, 2, 3], first_step=0.2, max_accel=1
2447+
),
2448+
expected=[
2449+
{"idx": 0, "lam": 1, "profile": [d(0.643, 0.357), d(0.5883, 0.41168)]},
2450+
{"idx": 1, "lam": 2, "profile": [d(0.7727, 0.2273), d(0.6117, 0.3883)]},
2451+
{"idx": 2, "lam": 3, "profile": [d(0.8595, 0.1405), d(0.6038, 0.39618)]},
2452+
],
2453+
prob_tol=TOL_LARGE,
2454+
lam_tol=TOL_LARGE,
2455+
),
2456+
marks=pytest.mark.qre_logit,
2457+
id="test_logit_lambda_1",
2458+
),
2459+
]
2460+
2461+
2462+
CASES = []
2463+
CASES += LOGIT_BRANCH_CASES
2464+
CASES += LOGIT_LAMBDA_CASES
2465+
24322466

2433-
# [LogitQREMixedStrategyProfile(lam=0.000000,profile=[[0.5, 0.5], [0.5, 0.5]])]
2467+
@pytest.mark.nash
2468+
@pytest.mark.parametrize("test_case", CASES, ids=lambda c: c.label)
2469+
def test_qre_solver(test_case: QREquilibriumTestCase, subtests) -> None:
2470+
"""Test calls of QRE solvers.
2471+
2472+
Subtests:
2473+
- Expected value of lambda for given idx,
2474+
difference in lambda not more than `test_case.lam_tol`
2475+
- Expected profile for given idx and lambda,
2476+
difference in probabilities is no more than `test_case.prob_tol`
2477+
"""
2478+
game = test_case.factory()
2479+
result = test_case.solver(game)
2480+
2481+
for i, exp in enumerate(test_case.expected):
2482+
found = result[exp["idx"]]
2483+
with subtests.test(eq=i, check="lambda"):
2484+
assert abs(exp["lam"] - found.lam) <= test_case.lam_tol
2485+
with subtests.test(eq=i, check="strategy_profile"):
2486+
exp_profile = game.mixed_strategy_profile(rational=True, data=exp["profile"])
2487+
for player in game.players:
2488+
for s in player.strategies:
2489+
assert abs(found.profile[s] - exp_profile[s]) <= test_case.prob_tol
24342490

24352491

24362492
##################################################################################################

0 commit comments

Comments
 (0)