diff --git a/src/main/java/core/AbstractGameState.java b/src/main/java/core/AbstractGameState.java index 2a37b1f45..4f5cdc88d 100644 --- a/src/main/java/core/AbstractGameState.java +++ b/src/main/java/core/AbstractGameState.java @@ -681,6 +681,24 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(playerResults); return result; } + + /** + * Returns the reward signal for reinforcement-learning style environments (e.g., PyTAG). + *

+ * By default, this method returns the game score for the specified player + * (see {@link #getGameScore(int)}). Games may override this method to provide + * an alternative or intermediate reward signal (e.g., phase-based or shaped rewards) + * while leaving the official game score unchanged. + *

+ * If overridden, PyTAG will use this method as the reward signal instead of the + * default score-based reward. + * + * @param playerId the player for whom the reward is being computed + * @return the reward value for the given player + */ + public double getReward(int playerId) { + return getGameScore(playerId); + } /** * HashCodeArray compiles all necessary hash codes for each individual game state. diff --git a/src/main/java/core/Game.java b/src/main/java/core/Game.java index ab377e686..8e7165d38 100644 --- a/src/main/java/core/Game.java +++ b/src/main/java/core/Game.java @@ -720,4 +720,4 @@ public static void main(String[] args) { } -} +} \ No newline at end of file diff --git a/src/main/java/core/PyTAG.java b/src/main/java/core/PyTAG.java index a16d37ed2..9f1e8a767 100644 --- a/src/main/java/core/PyTAG.java +++ b/src/main/java/core/PyTAG.java @@ -264,7 +264,7 @@ public boolean isDone(){ } public double getReward(){ - return gameState.getGameScore(gameState.getCurrentPlayer()); + return gameState.getReward(gameState.getCurrentPlayer()); } public List getActions(){ @@ -424,6 +424,7 @@ public static void main(String[] args) { // double[] obs = env.getObservationVector(); // String json = env.getObservationJson(); double reward = env.getReward(); + System.out.println("REWARD: " + reward); // System.out.println("at step " + steps + " the reward is " + reward + "player ID " + env.gameState.getCurrentPlayer()); // step the environment diff --git a/src/main/java/games/powergrid/PowerGridGameState.java b/src/main/java/games/powergrid/PowerGridGameState.java index 9172ed2e8..6dc1c83b9 100644 --- a/src/main/java/games/powergrid/PowerGridGameState.java +++ b/src/main/java/games/powergrid/PowerGridGameState.java @@ -2,6 +2,7 @@ import core.AbstractGameState; import core.AbstractParameters; +import core.CoreConstants; import core.components.Component; import core.components.Deck; import core.interfaces.IGamePhase; @@ -833,6 +834,73 @@ public double getGameScore(int playerId) { return cities + (money / scale); } + /** + * Computes the game score for the given player, used for intermediate + * reward shaping and final outcome scoring. + * + *

After the Bureaucracy phase, the score is calculated based on a + * a weighted combination of: + *

+ * + *

At game end: + *

+ * + * @param playerId the player whose score is being calculated + * @return a reward value reflecting current progress or final result + */ + @Override + public double getReward(int playerId) { + final PowerGridGamePhase phase = (PowerGridGamePhase) getGamePhase(); + + final double wCities = 0.8; + final double wIncome = 1.2; + + + //if the player wins at the end they get a reward else they get penalized + if (this.getGameStatus() == CoreConstants.GameResult.GAME_END) { + CoreConstants.GameResult[] res = getPlayerResults(); + CoreConstants.GameResult r = (res != null && playerId < res.length) ? res[playerId] : null; + if (r == CoreConstants.GameResult.WIN_GAME) { + return 4.0; + } else { + return -2.0; + } + } + + double base = 0.0; + if (phase == PowerGridGamePhase.BUREAUCRACY) { + int cityTarget = 0; + try { + PowerGridParameters params = (PowerGridParameters) getGameParameters(); + if (params != null && params.citiesToTriggerEnd != null && + params.citiesToTriggerEnd.length >= getNPlayers()) { + cityTarget = params.citiesToTriggerEnd[getNPlayers() - 1]; + } + } catch (Exception ignored) {} + if (cityTarget <= 0) cityTarget = Math.max(1, getMaxCitiesOwned()); + + int cities = getCityCountByPlayer(playerId); + + double normCitiesOwned = clamp01((double) cities / cityTarget); + double normPoweredCities = clamp01((double) poweredCities[playerId]/20); + + base = (wCities * normCitiesOwned + wIncome * normPoweredCities ); + } + + + return base; + + } + private static double clamp01(double x) { + return x < 0 ? 0 : (x > 1 ? 1 : x); + } +