Skip to content

Commit 580fb0d

Browse files
authored
Refactor MixedStrategyProfile cacheing (#715)
This more carefully refactors cacheing of MixedStrategyProfiles: * Introduces a discrete Cache object internally * Ensures all mutations of the profile (and its internal representation) invalidate any cached quantities
1 parent 04a620d commit 580fb0d

11 files changed

Lines changed: 123 additions & 90 deletions

File tree

src/games/game.cc

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ MixedStrategyProfile<T> &MixedStrategyProfile<T>::operator=(const MixedStrategyP
316316
template <class T> Vector<T> MixedStrategyProfile<T>::GetStrategy(const GamePlayer &p_player) const
317317
{
318318
CheckVersion();
319-
auto strategies = m_rep->m_support.GetStrategies(p_player);
319+
auto strategies = m_rep->GetSupport().GetStrategies(p_player);
320320
Vector<T> probs(strategies.size());
321321
std::transform(strategies.begin(), strategies.end(), probs.begin(),
322322
[this](const GameStrategy &s) { return (*m_rep)[s]; });
@@ -326,12 +326,12 @@ template <class T> Vector<T> MixedStrategyProfile<T>::GetStrategy(const GamePlay
326326
template <class T> MixedStrategyProfile<T> MixedStrategyProfile<T>::ToFullSupport() const
327327
{
328328
CheckVersion();
329-
MixedStrategyProfile<T> full(m_rep->m_support.GetGame()->NewMixedStrategyProfile(T(0)));
329+
MixedStrategyProfile<T> full(m_rep->GetSupport().GetGame()->NewMixedStrategyProfile(T(0)));
330330

331-
for (const auto &player : m_rep->m_support.GetGame()->GetPlayers()) {
331+
for (const auto &player : m_rep->GetSupport().GetGame()->GetPlayers()) {
332332
for (const auto &strategy : player->GetStrategies()) {
333333
full[strategy] =
334-
(m_rep->m_support.Contains(strategy)) ? (*m_rep)[strategy] : static_cast<T>(0);
334+
(m_rep->GetSupport().Contains(strategy)) ? (*m_rep)[strategy] : static_cast<T>(0);
335335
}
336336
}
337337
return full;
@@ -343,17 +343,18 @@ template <class T> MixedStrategyProfile<T> MixedStrategyProfile<T>::ToFullSuppor
343343

344344
template <class T> void MixedStrategyProfile<T>::ComputePayoffs() const
345345
{
346-
if (!m_payoffs.empty()) {
347-
// caches (m_payoffs and m_strategyValues) are valid,
348-
// so don't compute anything, simply return
346+
if (m_cache.m_valid) {
349347
return;
350348
}
351-
for (const auto &player : m_rep->m_support.GetPlayers()) {
352-
m_payoffs[player] = GetPayoff(player);
353-
for (const auto &strategy : m_rep->m_support.GetStrategies(player)) {
354-
m_strategyValues[player][strategy] = GetPayoff(strategy);
349+
Cache newCache;
350+
for (const auto &player : m_rep->GetSupport().GetPlayers()) {
351+
newCache.m_payoffs[player] = GetPayoff(player);
352+
for (const auto &strategy : m_rep->GetSupport().GetStrategies(player)) {
353+
newCache.m_strategyValues[player][strategy] = GetPayoff(strategy);
355354
}
356355
}
356+
newCache.m_valid = true;
357+
m_cache = std::move(newCache);
357358
};
358359

359360
template <class T> T MixedStrategyProfile<T>::GetLiapValue() const
@@ -362,12 +363,11 @@ template <class T> T MixedStrategyProfile<T>::GetLiapValue() const
362363
ComputePayoffs();
363364

364365
auto liapValue = static_cast<T>(0);
365-
for (auto p : m_payoffs) {
366-
liapValue += std::transform_reduce(
367-
m_strategyValues.at(p.first).begin(), m_strategyValues.at(p.first).end(),
368-
static_cast<T>(0), std::plus<T>(), [&p](const auto &v) -> T {
369-
return sqr(std::max(v.second - p.second, static_cast<T>(0)));
370-
});
366+
for (const auto &p : m_cache.m_payoffs) {
367+
const auto &values = m_cache.m_strategyValues.at(p.first);
368+
liapValue += sum_function(values, [&](const auto &v) {
369+
return sqr(std::max(v.second - p.second, static_cast<T>(0)));
370+
});
371371
}
372372
return liapValue;
373373
}
@@ -376,13 +376,14 @@ template <class T> T MixedStrategyProfile<T>::GetRegret(const GameStrategy &p_st
376376
{
377377
CheckVersion();
378378
ComputePayoffs();
379+
379380
auto player = p_strategy->GetPlayer();
380381
T best_other_payoff = maximize_function(
381382
filter_if(player->GetStrategies(), [&](const auto &s) { return s != p_strategy; }),
382383
[this, &player](const auto &strategy) -> T {
383-
return m_strategyValues.at(player).at(strategy);
384+
return m_cache.m_strategyValues.at(player).at(strategy);
384385
});
385-
return std::max(best_other_payoff - m_strategyValues.at(player).at(p_strategy),
386+
return std::max(best_other_payoff - m_cache.m_strategyValues.at(player).at(p_strategy),
386387
static_cast<T>(0));
387388
}
388389

@@ -392,9 +393,9 @@ template <class T> T MixedStrategyProfile<T>::GetRegret(const GamePlayer &p_play
392393
ComputePayoffs();
393394
auto br_payoff =
394395
maximize_function(p_player->GetStrategies(), [this, p_player](const auto &strategy) -> T {
395-
return m_strategyValues.at(p_player).at(strategy);
396+
return m_cache.m_strategyValues.at(p_player).at(strategy);
396397
});
397-
return br_payoff - m_payoffs.at(p_player);
398+
return br_payoff - m_cache.m_payoffs.at(p_player);
398399
}
399400

400401
template <class T> T MixedStrategyProfile<T>::GetMaxRegret() const

src/games/gametable.cc

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ Rational TablePureStrategyProfileRep::GetPayoff(const GamePlayer &p_player) cons
103103
Rational TablePureStrategyProfileRep::GetStrategyValue(const GameStrategy &p_strategy) const
104104
{
105105
const auto &player = p_strategy->GetPlayer();
106-
GameOutcomeRep *outcome =
106+
const GameOutcomeRep *outcome =
107107
dynamic_cast<GameTableRep &>(*m_nfg)
108108
.m_results[m_index - m_profile.at(player)->m_offset + p_strategy->m_offset];
109109
if (outcome) {
@@ -159,17 +159,18 @@ std::unique_ptr<MixedStrategyProfileRep<T>> TableMixedStrategyProfileRep<T>::Cop
159159
template <class T>
160160
T TableMixedStrategyProfileRep<T>::GetPayoff(int pl, int index, int current) const
161161
{
162-
if (current > static_cast<int>(this->m_support.GetGame()->NumPlayers())) {
163-
const Game game = this->m_support.GetGame();
162+
if (current > static_cast<int>(this->GetSupport().GetGame()->NumPlayers())) {
163+
const Game game = this->GetSupport().GetGame();
164164
auto &g = dynamic_cast<GameTableRep &>(*game);
165165
if (const auto outcome = g.m_results[index]) {
166-
return outcome->GetPayoff<T>(this->m_support.GetGame()->GetPlayer(pl));
166+
return outcome->GetPayoff<T>(this->GetSupport().GetGame()->GetPlayer(pl));
167167
}
168168
return static_cast<T>(0);
169169
}
170170

171171
T sum = static_cast<T>(0);
172-
for (auto s : this->m_support.GetStrategies(this->m_support.GetGame()->GetPlayer(current))) {
172+
for (auto s :
173+
this->GetSupport().GetStrategies(this->GetSupport().GetGame()->GetPlayer(current))) {
173174
if ((*this)[s] != T(0)) {
174175
sum += ((*this)[s] * GetPayoff(pl, index + s->m_offset, current + 1));
175176
}
@@ -189,15 +190,16 @@ void TableMixedStrategyProfileRep<T>::GetPayoffDeriv(int pl, int const_pl, int c
189190
if (cur_pl == const_pl) {
190191
cur_pl++;
191192
}
192-
if (cur_pl > static_cast<int>(this->m_support.GetGame()->NumPlayers())) {
193-
const Game game = this->m_support.GetGame();
193+
if (cur_pl > static_cast<int>(this->GetSupport().GetGame()->NumPlayers())) {
194+
const Game game = this->GetSupport().GetGame();
194195
auto &g = dynamic_cast<GameTableRep &>(*game);
195196
if (const auto outcome = g.m_results[index]) {
196-
value += prob * outcome->GetPayoff<T>(this->m_support.GetGame()->GetPlayer(pl));
197+
value += prob * outcome->GetPayoff<T>(this->GetSupport().GetGame()->GetPlayer(pl));
197198
}
198199
}
199200
else {
200-
for (auto s : this->m_support.GetStrategies(this->m_support.GetGame()->GetPlayer(cur_pl))) {
201+
for (auto s :
202+
this->GetSupport().GetStrategies(this->GetSupport().GetGame()->GetPlayer(cur_pl))) {
201203
if ((*this)[s] > T(0)) {
202204
GetPayoffDeriv(pl, const_pl, cur_pl + 1, index + s->m_offset, prob * (*this)[s], value);
203205
}
@@ -221,15 +223,16 @@ void TableMixedStrategyProfileRep<T>::GetPayoffDeriv(int pl, int const_pl1, int
221223
while (cur_pl == const_pl1 || cur_pl == const_pl2) {
222224
cur_pl++;
223225
}
224-
if (cur_pl > static_cast<int>(this->m_support.GetGame()->NumPlayers())) {
225-
const Game game = this->m_support.GetGame();
226+
if (cur_pl > static_cast<int>(this->GetSupport().GetGame()->NumPlayers())) {
227+
const Game game = this->GetSupport().GetGame();
226228
auto &g = dynamic_cast<GameTableRep &>(*game);
227229
if (const auto outcome = g.m_results[index]) {
228-
value += prob * outcome->GetPayoff<T>(this->m_support.GetGame()->GetPlayer(pl));
230+
value += prob * outcome->GetPayoff<T>(this->GetSupport().GetGame()->GetPlayer(pl));
229231
}
230232
}
231233
else {
232-
for (auto s : this->m_support.GetStrategies(this->m_support.GetGame()->GetPlayer(cur_pl))) {
234+
for (auto s :
235+
this->GetSupport().GetStrategies(this->GetSupport().GetGame()->GetPlayer(cur_pl))) {
233236
if ((*this)[s] > static_cast<T>(0)) {
234237
GetPayoffDeriv(pl, const_pl1, const_pl2, cur_pl + 1, index + s->m_offset,
235238
prob * (*this)[s], value);
@@ -524,7 +527,7 @@ void GameTableRep::RebuildTable()
524527
IndexStrategies();
525528
}
526529

527-
void GameTableRep::IndexStrategies()
530+
void GameTableRep::IndexStrategies() const
528531
{
529532
long offset = 1L;
530533
for (auto player : m_players) {

src/games/gametable.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class GameTableRep : public GameExplicitRep {
4040

4141
/// @name Private auxiliary functions
4242
//@{
43-
void IndexStrategies();
43+
void IndexStrategies() const;
4444
void RebuildTable();
4545
//@}
4646

src/games/gametree.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,20 @@ std::unique_ptr<MixedStrategyProfileRep<T>> TreeMixedStrategyProfileRep<T>::Copy
5757

5858
template <class T> void TreeMixedStrategyProfileRep<T>::MakeBehavior() const
5959
{
60-
if (mixed_behav_profile_sptr.get() == nullptr) {
61-
mixed_behav_profile_sptr =
62-
std::make_shared<MixedBehaviorProfile<T>>(MixedStrategyProfile<T>(Copy()));
60+
if (m_mixedBehavior == nullptr) {
61+
m_mixedBehavior = std::make_shared<MixedBehaviorProfile<T>>(MixedStrategyProfile<T>(Copy()));
6362
}
6463
}
6564

66-
template <class T> void TreeMixedStrategyProfileRep<T>::InvalidateCache() const
65+
template <class T> void TreeMixedStrategyProfileRep<T>::OnProfileChanged() const
6766
{
68-
mixed_behav_profile_sptr = nullptr;
67+
m_mixedBehavior = nullptr;
6968
}
7069

7170
template <class T> T TreeMixedStrategyProfileRep<T>::GetPayoff(int pl) const
7271
{
7372
MakeBehavior();
74-
return mixed_behav_profile_sptr->GetPayoff(mixed_behav_profile_sptr->GetGame()->GetPlayer(pl));
73+
return m_mixedBehavior->GetPayoff(m_mixedBehavior->GetGame()->GetPlayer(pl));
7574
}
7675

7776
template <class T>

src/games/gametree.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ template <class T> class TreeMixedStrategyProfileRep : public MixedStrategyProfi
191191
T GetPayoffDeriv(int pl, const GameStrategy &, const GameStrategy &) const override;
192192

193193
private:
194-
mutable std::shared_ptr<MixedBehaviorProfile<T>> mixed_behav_profile_sptr;
194+
mutable std::shared_ptr<MixedBehaviorProfile<T>> m_mixedBehavior;
195195

196196
void MakeBehavior() const;
197-
void InvalidateCache() const override;
197+
void OnProfileChanged() const override;
198198
};
199199

200200
} // namespace Gambit

0 commit comments

Comments
 (0)