diff --git a/src/main/java/core/AbstractGameState.java b/src/main/java/core/AbstractGameState.java index 2a37b1f45..4f5cdc88d 100644 --- a/src/main/java/core/AbstractGameState.java +++ b/src/main/java/core/AbstractGameState.java @@ -681,6 +681,24 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(playerResults); return result; } + + /** + * Returns the reward signal for reinforcement-learning style environments (e.g., PyTAG). + *
+ * By default, this method returns the game score for the specified player + * (see {@link #getGameScore(int)}). Games may override this method to provide + * an alternative or intermediate reward signal (e.g., phase-based or shaped rewards) + * while leaving the official game score unchanged. + *
+ * If overridden, PyTAG will use this method as the reward signal instead of the
+ * default score-based reward.
+ *
+ * @param playerId the player for whom the reward is being computed
+ * @return the reward value for the given player
+ */
+ public double getReward(int playerId) {
+ return getGameScore(playerId);
+ }
/**
* HashCodeArray compiles all necessary hash codes for each individual game state.
diff --git a/src/main/java/core/Game.java b/src/main/java/core/Game.java
index ab377e686..8e7165d38 100644
--- a/src/main/java/core/Game.java
+++ b/src/main/java/core/Game.java
@@ -720,4 +720,4 @@ public static void main(String[] args) {
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/core/PyTAG.java b/src/main/java/core/PyTAG.java
index a16d37ed2..9f1e8a767 100644
--- a/src/main/java/core/PyTAG.java
+++ b/src/main/java/core/PyTAG.java
@@ -264,7 +264,7 @@ public boolean isDone(){
}
public double getReward(){
- return gameState.getGameScore(gameState.getCurrentPlayer());
+ return gameState.getReward(gameState.getCurrentPlayer());
}
public List After the Bureaucracy phase, the score is calculated based on a
+ * a weighted combination of:
+ * At game end:
+ *
+ *
+ *
+ *
+ *
+ *
+ * @param playerId the player whose score is being calculated
+ * @return a reward value reflecting current progress or final result
+ */
+ @Override
+ public double getReward(int playerId) {
+ final PowerGridGamePhase phase = (PowerGridGamePhase) getGamePhase();
+
+ final double wCities = 0.8;
+ final double wIncome = 1.2;
+
+
+ //if the player wins at the end they get a reward else they get penalized
+ if (this.getGameStatus() == CoreConstants.GameResult.GAME_END) {
+ CoreConstants.GameResult[] res = getPlayerResults();
+ CoreConstants.GameResult r = (res != null && playerId < res.length) ? res[playerId] : null;
+ if (r == CoreConstants.GameResult.WIN_GAME) {
+ return 4.0;
+ } else {
+ return -2.0;
+ }
+ }
+
+ double base = 0.0;
+ if (phase == PowerGridGamePhase.BUREAUCRACY) {
+ int cityTarget = 0;
+ try {
+ PowerGridParameters params = (PowerGridParameters) getGameParameters();
+ if (params != null && params.citiesToTriggerEnd != null &&
+ params.citiesToTriggerEnd.length >= getNPlayers()) {
+ cityTarget = params.citiesToTriggerEnd[getNPlayers() - 1];
+ }
+ } catch (Exception ignored) {}
+ if (cityTarget <= 0) cityTarget = Math.max(1, getMaxCitiesOwned());
+
+ int cities = getCityCountByPlayer(playerId);
+
+ double normCitiesOwned = clamp01((double) cities / cityTarget);
+ double normPoweredCities = clamp01((double) poweredCities[playerId]/20);
+
+ base = (wCities * normCitiesOwned + wIncome * normPoweredCities );
+ }
+
+
+ return base;
+
+ }
+ private static double clamp01(double x) {
+ return x < 0 ? 0 : (x > 1 ? 1 : x);
+ }
+