diff --git a/README.md b/README.md
index e5697e0..3784000 100644
--- a/README.md
+++ b/README.md
@@ -95,7 +95,6 @@ Here is a list of all files in the `src` folder and their purposes:
 - **`TicTacToe/Agent.py`**: Defines the base agent class for the game.
 - **`TicTacToe/DeepQAgent.py`**: Implements a deep Q-learning agent.
 - **`TicTacToe/Display.py`**: Handles the display of the game board.
-- **`TicTacToe/DisplayTest.py`**: Contains tests for the display module.
 - **`TicTacToe/EquivariantNN.py`**: Implements equivariant neural networks for symmetry-aware learning.
 - **`TicTacToe/Evaluation.py`**: Provides evaluation metrics for agents.
 - **`TicTacToe/game_types.py`**: Defines types and constants used in the game.
diff --git a/pyproject.toml b/pyproject.toml
index fd565e6..09ae2ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,5 +81,4 @@ exclude = ["**/*.ipynb"]
 omit = [
     # omit this single file
     "src/TicTacToe/Evaluation.py",
-    "src/TicTacToe/DisplayTest.py",
 ]
diff --git a/src/TicTacToe/DeepQAgent.py b/src/TicTacToe/DeepQAgent.py
index a50283d..4870aa4 100644
--- a/src/TicTacToe/DeepQAgent.py
+++ b/src/TicTacToe/DeepQAgent.py
@@ -303,10 +303,14 @@ def _init_symmetrized_loss(self, params: dict[str, Any]) -> None:
             lambda x: np.flipud(np.transpose(x)),
             lambda x: np.flipud(np.fliplr(np.transpose(x))),
         ]
-        if params.get("symmetrized_loss", True):
+        if params.get("symmetrized_loss", True) and params.get("replay_buffer_type", "uniform") == "uniform":
             self.compute_loss = self.create_symmetrized_loss(
                 self.compute_standard_loss, self.transformations, self.rows
             )
+        elif params.get("symmetrized_loss", True) and params.get("replay_buffer_type", "uniform") == "prioritized":
+            self.compute_loss = self.create_symmetrized_loss(
+                self.compute_prioritized_loss, self.transformations, self.rows
+            )
         elif params.get("replay_buffer_type", "uniform") == "prioritized":
             self.compute_loss = self.compute_prioritized_loss
         else:
@@ -403,7 +407,6 @@ def compute_standard_loss(
             The computed loss.
         """
         states, actions, rewards, next_states, dones = samples
-        # print(f"states.shape = {states.shape}")
         q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
         next_q_values = self.target_network(next_states).max(1, keepdim=True)[0].squeeze(1)
         targets = rewards + (~dones) * self.gamma * next_q_values
@@ -585,7 +588,6 @@ def get_best_action(self, board: Board, q_network: nn.Module) -> Action:
         """
         state = self.board_to_state(board)
         state_tensor = torch.FloatTensor(state).to(self.device)
-        # print(f"state_tensor.shape = {state_tensor.shape}")
        with torch.no_grad():
             q_values = q_network(state_tensor).squeeze()
             max_q, _ = torch.max(q_values, dim=0)
@@ -607,10 +609,8 @@ class DeepQPlayingAgent(Agent):
 
     def __init__(self,
                  q_network: nn.Module | str,
-                 player: Player = "X",
-                 switching: bool = False,
-                 device : str = "cpu",
-                 state_shape: str = "flat") -> None:
+                 params: dict
+                 ) -> None:
         """
         Initialize the DeepQPlayingAgent.
 
@@ -619,6 +619,12 @@ def __init__(self,
             player: The player symbol ("X" or "O").
             switching: Whether to switch players after each game.
         """
+        player = params["player"]
+        switching = params["switching"]
+        device = params["device"]
+        state_shape = params["state_shape"]
+        rows = params["rows"]
+
         super().__init__(player=player, switching=switching)
         self.device = torch.device(device)
 
@@ -631,9 +637,9 @@ def __init__(self,
         if state_shape == "flat":
             self.state_converter = FlatStateConverter()
         elif state_shape == "2D":
-            self.state_converter = GridStateConverter(shape=(3, 3)) # Assuming a 3x3 grid
+            self.state_converter = GridStateConverter(shape=(rows, rows))
         elif state_shape == "one-hot":
-            self.state_converter = OneHotStateConverter(rows=3) # Assuming a 3x3 grid
+            self.state_converter = OneHotStateConverter(rows=rows)
         else:
             raise ValueError(f"Unsupported state shape: {state_shape}")
 
diff --git a/src/TicTacToe/DisplayTest.py b/src/TicTacToe/DisplayTest.py
deleted file mode 100644
index bb58b4a..0000000
--- a/src/TicTacToe/DisplayTest.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# %%
-from TicTacToe.Agent import MouseAgent
-from TicTacToe.DeepQAgent import DeepQPlayingAgent
-from TicTacToe.Display import ScreenDisplay
-from TicTacToe.TicTacToe import TicTacToe
-
-rows = 4
-win_length = 4
-# agent1 = RandomAgent(player='X', switching=False)
-# agent1 = HumanAgent(player='X')
-agent1 = MouseAgent(player="O")
-# agent2 = RandomAgent(player='O', switching=False)
-# agent1 = DeepQPlayingAgent(player='X', q_network='models/q_network_4x4x4.pth')
-agent2 = DeepQPlayingAgent(player="X", q_network="models/q_network_4x4x4.pth")
-# display = ConsoleDisplay(rows=rows, cols=rows, waiting_time=0.5)
-display = ScreenDisplay(rows=rows, cols=rows, waiting_time=0.5)
-game = TicTacToe(agent1, agent2, display=display, rows=rows, cols=rows, win_length=win_length)
diff --git a/src/TicTacToe/Evaluation.py b/src/TicTacToe/Evaluation.py
index 6c834ca..d11a1c7 100644
--- a/src/TicTacToe/Evaluation.py
+++ b/src/TicTacToe/Evaluation.py
@@ -1,4 +1,5 @@
 import wandb
+import copy
 
 from typing import Any, Optional
 
@@ -267,12 +268,13 @@ def evaluate_performance(
         dict[str, float]: A dictionary containing evaluation metrics.
     """
     wandb_logging = params["wandb_logging"]
-    device = params["device"]
-    state_shape = params["state_shape"]
     evaluation_batch_size = params["evaluation_batch_size"]
+    playing_params = copy.deepcopy(params)
+
     q_network1 = learning_agent1.q_network
-    playing_agent1 = DeepQPlayingAgent(q_network1, player="X", switching=False, device=device, state_shape=state_shape)
+    playing_params["player"] = "X"
+    playing_agent1 = DeepQPlayingAgent(q_network1, params=playing_params)
     random_agent2 = RandomAgent(player="O", switching=False)
 
     all_data = {}
@@ -294,7 +296,8 @@ def evaluate_performance(
         wandb.log(data)
 
     q_network2 = learning_agent2.q_network
-    playing_agent2 = DeepQPlayingAgent(q_network2, player="O", switching=False, device=device, state_shape=state_shape)
+    playing_params["player"] = "O"
+    playing_agent2 = DeepQPlayingAgent(q_network2, params=playing_params)
     random_agent1 = RandomAgent(player="X", switching=False)
 
     game = TicTacToe(random_agent1, playing_agent2, display=None, params=params)
diff --git a/tests/test_deep_q_agent.py b/tests/test_deep_q_agent.py
index 2a1a3f7..5c5f5d4 100644
--- a/tests/test_deep_q_agent.py
+++ b/tests/test_deep_q_agent.py
@@ -469,7 +469,14 @@ def get_board(self):
 class TestDeepQPlayingAgent(unittest.TestCase):
     def setUp(self):
         self.q_network = MockQNetwork()
-        self.agent = DeepQPlayingAgent(q_network=self.q_network, player="X", switching=True)
+        self.params = {
+            "player": "X",
+            "switching": False,
+            "rows": 3,
+            "device": "cpu",
+            "state_shape": "flat",
+        }
+        self.agent = DeepQPlayingAgent(q_network=self.q_network, params=self.params)
 
     def test_board_to_state(self):
         board = ["X", " ", "O", " ", " ", "X", "O", " ", " "]
@@ -493,7 +500,7 @@ def test_choose_action(self):
 
     @patch("torch.load", return_value=MockQNetwork())
     def test_q_network_loading(self, mock_load):
-        agent = DeepQPlayingAgent(q_network="mock_path.pth", player="X", switching=False)
+        agent = DeepQPlayingAgent(q_network="mock_path.pth", params=self.params)
         self.assertIsInstance(agent.q_network, MockQNetwork)
         mock_load.assert_called_once_with("mock_path.pth", weights_only=False)
 
@@ -507,11 +514,4 @@ def test_get_action_done(self):
         mock_game = MockTicTacToe()
         state_transition = (None, None, True) # Done flag is True.
         action = self.agent.get_action(state_transition, mock_game)
-        self.assertEqual(action, -1) # Game is over, no action taken.
-
-    def test_on_game_end(self):
-        initial_player = self.agent.player
-        initial_opponent = self.agent.opponent
-        self.agent.on_game_end(None) # Pass None for game, not used.
-        self.assertEqual(self.agent.player, initial_opponent)
-        self.assertEqual(self.agent.opponent, initial_player)
\ No newline at end of file
+        self.assertEqual(action, -1) # Game is over, no action taken.
\ No newline at end of file
diff --git a/train_and_play/play_O_against_model.py b/train_and_play/play_O_against_model.py
index 5536d92..324bf67 100644
--- a/train_and_play/play_O_against_model.py
+++ b/train_and_play/play_O_against_model.py
@@ -23,14 +23,23 @@
 relative_folder = (script_dir / '../models/all_models').resolve()
 model_path = f"{relative_folder}/q_network_3x3x3_O.pth"
 
+params = {
+    "player": "X", # Player symbol for the agent
+    "rows": 3, # Board size (rows x rows)
+    "win_length": 3, # Number of in-a-row needed to win
+    "rewards": {
+        "W": 1.0, # Reward for a win
+        "L": -1.0, # Reward for a loss
+        "D": 0.5, # Reward for a draw
+    },
+}
+
 # Set up the game
-rows = 3
-win_length = 3
 agent1 = MouseAgent(player="O")
-agent2 = DeepQPlayingAgent(q_network=model_path, player="X")
-display = ScreenDisplay(rows=rows, cols=rows, waiting_time=0.5)
+agent2 = DeepQPlayingAgent(q_network=model_path, params=params)
+display = ScreenDisplay(rows=params["rows"], cols=params["rows"], waiting_time=0.5)
 
-game = TicTacToe(agent1, agent2, display=display, rows=rows, cols=rows, win_length=win_length)
+game = TicTacToe(agent1, agent2, display=display, params=params)
 
 # Play the game
 game.play()
\ No newline at end of file
diff --git a/train_and_play/play_X_against_model.py b/train_and_play/play_X_against_model.py
index 3f89b31..8aea075 100644
--- a/train_and_play/play_X_against_model.py
+++ b/train_and_play/play_X_against_model.py
@@ -23,14 +23,23 @@
 relative_folder = (script_dir / '../models/all_models').resolve()
 model_path = f"{relative_folder}/q_network_3x3x3_X.pth" # Change this path to the desired model
 
+params = {
+    "player": "O", # Player symbol for the agent
+    "rows": 3, # Board size (rows x rows)
+    "win_length": 3, # Number of in-a-row needed to win
+    "rewards": {
+        "W": 1.0, # Reward for a win
+        "L": -1.0, # Reward for a loss
+        "D": 0.5, # Reward for a draw
+    },
+}
+
 # Set up the game
-rows = 3
-win_length = 3
-agent1 = DeepQPlayingAgent(q_network=model_path, player="O")
+agent1 = DeepQPlayingAgent(q_network=model_path, params=params)
 agent2 = MouseAgent(player="X")
-display = ScreenDisplay(rows=rows, cols=rows, waiting_time=0.5)
+display = ScreenDisplay(rows=params["rows"], cols=params["rows"], waiting_time=0.5)
 
-game = TicTacToe(agent1, agent2, display=display, rows=rows, cols=rows, win_length=win_length, periodic=True)
+game = TicTacToe(agent1, agent2, display=display, params=params)
 
 # Play the game
 game.play()
\ No newline at end of file
diff --git a/train_and_play/train_dqn_sweep.py b/train_and_play/train_dqn_sweep.py
index 9c948e9..8c695f4 100644
--- a/train_and_play/train_dqn_sweep.py
+++ b/train_and_play/train_dqn_sweep.py
@@ -74,7 +74,7 @@
     "shared_replay_buffer": False, # Share replay buffer between agents
 
     # Q Network settings
-    "network_type": "CNN", # Network architecture: 'Equivariant', 'FullyCNN', 'FCN', 'CNN'
+    "network_type": "FullyCNN", # Network architecture: 'Equivariant', 'FullyCNN', 'FCN', 'CNN'
    "periodic": False, # Periodic boundary conditions
     "load_network": False, # Whether to load pretrained weights
     "project_name": "TicTacToe", # Weights & Biases project name
@@ -88,7 +88,7 @@
 
 # --- Sweep Setup ---
 # param_sweep = {"replay_buffer_type": ["prioritized", "uniform"], "periodic": [True, False], "state_shape": ["one-hot", "flat"]}
-param_sweep = {"periodic": [True, False], "state_shape": ["one-hot", "2D", "flat"], "network_type": ["CNN", "FullyCNN"]}
+param_sweep = {"replay_buffer_type": ["prioritized", "uniform"], "symmetrized_loss": [True, False], "state_shape": ["one-hot", "flat"]}
 sweep_combinations, param_keys = get_param_sweep_combinations(param_sweep)
 
 # --- Shared Replay Buffer Setup ---