Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/main/java/core/AbstractGameState.java
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,24 @@ public int hashCode() {
result = 31 * result + Arrays.hashCode(playerResults);
return result;
}

/**
* Returns the reward signal for reinforcement-learning style environments (e.g., PyTAG).
* <p>
* By default, this method returns the game score for the specified player
* (see {@link #getGameScore(int)}). Games may override this method to provide
* an alternative or intermediate reward signal (e.g., phase-based or shaped rewards)
* while leaving the official game score unchanged.
* <p>
* If overridden, PyTAG will use this method as the reward signal instead of the
* default score-based reward.
*
* @param playerId the player for whom the reward is being computed
* @return the reward value for the given player
*/
public double getReward(int playerId) {
return getGameScore(playerId);
}

/**
* HashCodeArray compiles all necessary hash codes for each individual game state.
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/core/Game.java
Original file line number Diff line number Diff line change
Expand Up @@ -720,4 +720,4 @@ public static void main(String[] args) {
}


}
}
3 changes: 2 additions & 1 deletion src/main/java/core/PyTAG.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ public boolean isDone(){
}

public double getReward(){
return gameState.getGameScore(gameState.getCurrentPlayer());
return gameState.getReward(gameState.getCurrentPlayer());
}

public List<AbstractAction> getActions(){
Expand Down Expand Up @@ -424,6 +424,7 @@ public static void main(String[] args) {
// double[] obs = env.getObservationVector();
// String json = env.getObservationJson();
double reward = env.getReward();
System.out.println("REWARD: " + reward);
// System.out.println("at step " + steps + " the reward is " + reward + "player ID " + env.gameState.getCurrentPlayer());

// step the environment
Expand Down
68 changes: 68 additions & 0 deletions src/main/java/games/powergrid/PowerGridGameState.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import core.AbstractGameState;
import core.AbstractParameters;
import core.CoreConstants;
import core.components.Component;
import core.components.Deck;
import core.interfaces.IGamePhase;
Expand Down Expand Up @@ -833,6 +834,73 @@ public double getGameScore(int playerId) {

return cities + (money / scale);
}
/**
* Computes the game score for the given player, used for intermediate
* reward shaping and final outcome scoring.
*
* <p>After the Bureaucracy phase, the score is calculated based on a
* a weighted combination of:
* <ul>
* <li>Normalized city count (progress toward game-end city target)</li>
* <li>Normalized income (based on powered cities)</li>
* </ul>
*
* <p>At game end:
* <ul>
* <li>Winners receive +4.0</li>
* <li>All non-winners receive −2.0</li>
* </ul>
*
* @param playerId the player whose score is being calculated
* @return a reward value reflecting current progress or final result
*/
@Override
public double getReward(int playerId) {
final PowerGridGamePhase phase = (PowerGridGamePhase) getGamePhase();

final double wCities = 0.8;
final double wIncome = 1.2;


//if the player wins at the end they get a reward else they get penalized
if (this.getGameStatus() == CoreConstants.GameResult.GAME_END) {
CoreConstants.GameResult[] res = getPlayerResults();
CoreConstants.GameResult r = (res != null && playerId < res.length) ? res[playerId] : null;
if (r == CoreConstants.GameResult.WIN_GAME) {
return 4.0;
} else {
return -2.0;
}
}

double base = 0.0;
if (phase == PowerGridGamePhase.BUREAUCRACY) {
int cityTarget = 0;
try {
PowerGridParameters params = (PowerGridParameters) getGameParameters();
if (params != null && params.citiesToTriggerEnd != null &&
params.citiesToTriggerEnd.length >= getNPlayers()) {
cityTarget = params.citiesToTriggerEnd[getNPlayers() - 1];
}
} catch (Exception ignored) {}
if (cityTarget <= 0) cityTarget = Math.max(1, getMaxCitiesOwned());

int cities = getCityCountByPlayer(playerId);

double normCitiesOwned = clamp01((double) cities / cityTarget);
double normPoweredCities = clamp01((double) poweredCities[playerId]/20);

base = (wCities * normCitiesOwned + wIncome * normPoweredCities );
}


return base;

}
private static double clamp01(double x) {
return x < 0 ? 0 : (x > 1 ? 1 : x);
}




Expand Down