diff --git a/cartpole.py b/cartpole.py index d16f116..34091c3 100644 --- a/cartpole.py +++ b/cartpole.py @@ -1,5 +1,5 @@ import random -import gym +import gymnasium import numpy as np from collections import deque from keras.models import Sequential @@ -61,28 +61,34 @@ def experience_replay(self): def cartpole(): - env = gym.make(ENV_NAME) + env = gymnasium.make(ENV_NAME) score_logger = ScoreLogger(ENV_NAME) observation_space = env.observation_space.shape[0] action_space = env.action_space.n dqn_solver = DQNSolver(observation_space, action_space) run = 0 - while True: + + ROUNDS = 100 + for i in range(ROUNDS): run += 1 - state = env.reset() + state = env.reset()[0] + # print([1, observation_space]) + # print(state) state = np.reshape(state, [1, observation_space]) step = 0 while True: step += 1 #env.render() action = dqn_solver.act(state) - state_next, reward, terminal, info = env.step(action) - reward = reward if not terminal else -reward + # print(env.step(action)) + state_next, reward, terminated, truncated, info = env.step(action) + reward = reward if not terminated else -reward state_next = np.reshape(state_next, [1, observation_space]) - dqn_solver.remember(state, action, reward, state_next, terminal) + dqn_solver.remember(state, action, reward, state_next, terminated) state = state_next - if terminal: - print "Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step) + if terminated: + print("RUN: ", i) + print ("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step)) score_logger.add_score(step, run) break dqn_solver.experience_replay() diff --git a/requirements.txt b/requirements.txt index 41953f3..f09af62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -gym +gymnasium keras matplotlib tensorflow diff --git a/scores/score_logger.py b/scores/score_logger.py index 8538bc7..095b17f 100644 --- a/scores/score_logger.py +++ b/scores/score_logger.py @@ -38,10 +38,10 @@ def add_score(self, score, run): show_legend=True) self.scores.append(score) mean_score = mean(self.scores) - print "Scores: (min: " + str(min(self.scores)) + ", avg: " + str(mean_score) + ", max: " + str(max(self.scores)) + ")\n" + print ("Scores: (min: " + str(min(self.scores)) + ", avg: " + str(mean_score) + ", max: " + str(max(self.scores)) + ")\n") if mean_score >= AVERAGE_SCORE_TO_SOLVE and len(self.scores) >= CONSECUTIVE_RUNS_TO_SOLVE: solve_score = run-CONSECUTIVE_RUNS_TO_SOLVE - print "Solved in " + str(solve_score) + " runs, " + str(run) + " total runs." + print ("Solved in " + str(solve_score) + " runs, " + str(run) + " total runs.") self._save_csv(SOLVED_CSV_PATH, solve_score) self._save_png(input_path=SOLVED_CSV_PATH, output_path=SOLVED_PNG_PATH, @@ -60,8 +60,14 @@ def _save_png(self, input_path, output_path, x_label, y_label, average_of_n_last reader = csv.reader(scores) data = list(reader) for i in range(0, len(data)): - x.append(int(i)) - y.append(int(data[i][0])) + # print(i) + # print(len(data)) + print(data) + if(len(data[i]) > 0): + x.append(int(i)) + y.append(int(data[i][0])) + else: + print("Missing data") plt.subplots() plt.plot(x, y, label="score per run") diff --git a/scores/scores.csv b/scores/scores.csv index 8eadc96..46f7bf8 100644 --- a/scores/scores.csv +++ b/scores/scores.csv @@ -1,112 +1,38 @@ -43 -40 -26 -32 -13 -13 -12 -11 -13 -15 -8 -11 -9 -13 -10 -10 -8 -14 -16 -9 -11 -12 -10 -8 -10 -10 -13 -9 -88 -97 -56 -28 -24 -41 -45 -29 -30 -68 -49 -34 -62 -67 -87 -59 -97 -69 -96 -109 -184 -201 -176 -139 -340 -238 -283 -237 -250 -374 -226 -256 -419 -230 -265 -280 -220 -260 -234 -240 -209 -500 -500 -424 -212 -500 -300 -269 -446 -209 -203 -251 -229 -203 -500 -232 -360 -388 -317 -184 -500 -500 -306 -500 -425 -464 -297 -346 -105 -10 -9 -11 -10 -11 -10 -440 -475 -500 -431 -500 -179 -500 -13 -500 +19 +19 +19 +20 +22 +9 +11 +16 +8 +11 +11 +14 +14 +13 +13 +19 +12 +9 +12 +9 +15 +12 +11 +13 +88 +56 +74 +51 +37 +45 +48 +32 +57 +42 +73 +129 +181 +140 diff --git a/scores/scores.png b/scores/scores.png index 3b7458f..24f2eb3 100644 Binary files a/scores/scores.png and b/scores/scores.png differ