1 change: 0 additions & 1 deletion README.md
@@ -95,7 +95,6 @@ Here is a list of all files in the `src` folder and their purposes:
- **`TicTacToe/Agent.py`**: Defines the base agent class for the game.
- **`TicTacToe/DeepQAgent.py`**: Implements a deep Q-learning agent.
- **`TicTacToe/Display.py`**: Handles the display of the game board.
- **`TicTacToe/DisplayTest.py`**: Contains tests for the display module.
- **`TicTacToe/EquivariantNN.py`**: Implements equivariant neural networks for symmetry-aware learning.
- **`TicTacToe/Evaluation.py`**: Provides evaluation metrics for agents.
- **`TicTacToe/game_types.py`**: Defines types and constants used in the game.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -81,5 +81,4 @@ exclude = ["**/*.ipynb"]
omit = [
# omit this single file
"src/TicTacToe/Evaluation.py",
"src/TicTacToe/DisplayTest.py",
]
24 changes: 15 additions & 9 deletions src/TicTacToe/DeepQAgent.py
@@ -303,10 +303,14 @@ def _init_symmetrized_loss(self, params: dict[str, Any]) -> None:
lambda x: np.flipud(np.transpose(x)),
lambda x: np.flipud(np.fliplr(np.transpose(x))),
]
if params.get("symmetrized_loss", True):
if params.get("symmetrized_loss", True) and params.get("replay_buffer_type", "uniform") == "uniform":
self.compute_loss = self.create_symmetrized_loss(
self.compute_standard_loss, self.transformations, self.rows
)
elif params.get("symmetrized_loss", True) and params.get("replay_buffer_type", "uniform") == "prioritized":
self.compute_loss = self.create_symmetrized_loss(
self.compute_prioritized_loss, self.transformations, self.rows
)
elif params.get("replay_buffer_type", "uniform") == "prioritized":
self.compute_loss = self.compute_prioritized_loss
else:
@@ -403,7 +407,6 @@ def compute_standard_loss(
The computed loss.
"""
states, actions, rewards, next_states, dones = samples
# print(f"states.shape = {states.shape}")
q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
next_q_values = self.target_network(next_states).max(1, keepdim=True)[0].squeeze(1)
targets = rewards + (~dones) * self.gamma * next_q_values
@@ -585,7 +588,6 @@ def get_best_action(self, board: Board, q_network: nn.Module) -> Action:
"""
state = self.board_to_state(board)
state_tensor = torch.FloatTensor(state).to(self.device)
# print(f"state_tensor.shape = {state_tensor.shape}")
with torch.no_grad():
q_values = q_network(state_tensor).squeeze()
max_q, _ = torch.max(q_values, dim=0)
@@ -607,10 +609,8 @@ class DeepQPlayingAgent(Agent):

def __init__(self,
q_network: nn.Module | str,
player: Player = "X",
switching: bool = False,
device : str = "cpu",
state_shape: str = "flat") -> None:
params: dict
) -> None:
"""
Initialize the DeepQPlayingAgent.

@@ -619,6 +619,12 @@ def __init__(self,
player: The player symbol ("X" or "O").
switching: Whether to switch players after each game.
"""
player = params["player"]
switching = params["switching"]
device = params["device"]
state_shape = params["state_shape"]
rows = params["rows"]

super().__init__(player=player, switching=switching)
self.device = torch.device(device)

@@ -631,9 +637,9 @@ def __init__(self,
if state_shape == "flat":
self.state_converter = FlatStateConverter()
elif state_shape == "2D":
self.state_converter = GridStateConverter(shape=(3, 3)) # Assuming a 3x3 grid
self.state_converter = GridStateConverter(shape=(rows, rows))
elif state_shape == "one-hot":
self.state_converter = OneHotStateConverter(rows=3) # Assuming a 3x3 grid
self.state_converter = OneHotStateConverter(rows=rows)
else:
raise ValueError(f"Unsupported state shape: {state_shape}")

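For context on the first hunk above: `_init_symmetrized_loss` now routes the prioritized-replay case through `create_symmetrized_loss(self.compute_prioritized_loss, ...)` as well, so symmetry averaging is no longer limited to the uniform buffer. The repo's `create_symmetrized_loss` is not shown in this diff; the sketch below is only a minimal illustration of the idea — averaging a base loss over the board's eight dihedral symmetries — and it assumes flat `(batch, rows*rows)` state tensors and flat-index actions, which may not match the actual implementation.

import numpy as np
import torch

def make_symmetrized_loss(base_loss, transformations, rows):
    """Illustrative sketch: average base_loss over board symmetries.

    Assumes flat (batch, rows*rows) states and integer actions indexing
    flat board cells; not the repo's actual create_symmetrized_loss.
    """
    idx = np.arange(rows * rows).reshape(rows, rows)

    def transform_batch(batch, t):
        # Apply the board symmetry t to every flat state in the batch.
        boards = batch.reshape(-1, rows, rows).cpu().numpy()
        out = np.stack([t(b) for b in boards]).reshape(batch.shape[0], -1)
        return torch.as_tensor(out, dtype=batch.dtype, device=batch.device)

    def loss(samples):
        states, actions, rewards, next_states, dones = samples
        total = 0.0
        for t in transformations:
            # new_pos[old_cell] = where that cell lands after transform t,
            # so actions are permuted consistently with the states.
            new_pos = np.argsort(t(idx).reshape(-1))
            t_actions = torch.as_tensor(new_pos, device=actions.device)[actions]
            total = total + base_loss(
                (transform_batch(states, t), t_actions, rewards,
                 transform_batch(next_states, t), dones)
            )
        return total / len(transformations)

    return loss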
17 changes: 0 additions & 17 deletions src/TicTacToe/DisplayTest.py

This file was deleted.

11 changes: 7 additions & 4 deletions src/TicTacToe/Evaluation.py
@@ -1,4 +1,5 @@
import wandb
import copy

from typing import Any, Optional

@@ -267,12 +268,13 @@ def evaluate_performance(
dict[str, float]: A dictionary containing evaluation metrics.
"""
wandb_logging = params["wandb_logging"]
device = params["device"]
state_shape = params["state_shape"]
evaluation_batch_size = params["evaluation_batch_size"]

playing_params = copy.deepcopy(params)

q_network1 = learning_agent1.q_network
playing_agent1 = DeepQPlayingAgent(q_network1, player="X", switching=False, device=device, state_shape=state_shape)
playing_params["player"] = "X"
playing_agent1 = DeepQPlayingAgent(q_network1, params=playing_params)
random_agent2 = RandomAgent(player="O", switching=False)
all_data = {}

@@ -294,7 +296,8 @@ def evaluate_performance(
wandb.log(data)

q_network2 = learning_agent2.q_network
playing_agent2 = DeepQPlayingAgent(q_network2, player="O", switching=False, device=device, state_shape=state_shape)
playing_params["player"] = "O"
playing_agent2 = DeepQPlayingAgent(q_network2, params=playing_params)
random_agent1 = RandomAgent(player="X", switching=False)

game = TicTacToe(random_agent1, playing_agent2, display=None, params=params)
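The `copy.deepcopy(params)` added above is what lets `evaluate_performance` override `"player"` per playing agent without mutating the training-time `params` dict (including nested entries such as `"rewards"`). A minimal illustration of the difference, with values taken from the play scripts below:

import copy

params = {"player": "X", "rewards": {"W": 1.0, "L": -1.0, "D": 0.5}}

playing_params = copy.deepcopy(params)
playing_params["player"] = "O"          # params["player"] is still "X"
playing_params["rewards"]["W"] = 2.0    # nested "rewards" dict is independent too

shared = params                         # plain assignment would only alias ...
shared["player"] = "O"                  # ... and rewrite params["player"] as well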
20 changes: 10 additions & 10 deletions tests/test_deep_q_agent.py
@@ -469,7 +469,14 @@ def get_board(self):
class TestDeepQPlayingAgent(unittest.TestCase):
def setUp(self):
self.q_network = MockQNetwork()
self.agent = DeepQPlayingAgent(q_network=self.q_network, player="X", switching=True)
self.params = {
"player": "X",
"switching": False,
"rows": 3,
"device": "cpu",
"state_shape": "flat",
}
self.agent = DeepQPlayingAgent(q_network=self.q_network, params=self.params)

def test_board_to_state(self):
board = ["X", " ", "O", " ", " ", "X", "O", " ", " "]
@@ -493,7 +500,7 @@ def test_choose_action(self):

@patch("torch.load", return_value=MockQNetwork())
def test_q_network_loading(self, mock_load):
agent = DeepQPlayingAgent(q_network="mock_path.pth", player="X", switching=False)
agent = DeepQPlayingAgent(q_network="mock_path.pth", params=self.params)
self.assertIsInstance(agent.q_network, MockQNetwork)
mock_load.assert_called_once_with("mock_path.pth", weights_only=False)

@@ -507,11 +514,4 @@ def test_get_action_done(self):
mock_game = MockTicTacToe()
state_transition = (None, None, True) # Done flag is True.
action = self.agent.get_action(state_transition, mock_game)
self.assertEqual(action, -1) # Game is over, no action taken.

def test_on_game_end(self):
initial_player = self.agent.player
initial_opponent = self.agent.opponent
self.agent.on_game_end(None) # Pass None for game, not used.
self.assertEqual(self.agent.player, initial_opponent)
self.assertEqual(self.agent.opponent, initial_player)
self.assertEqual(action, -1) # Game is over, no action taken.
19 changes: 14 additions & 5 deletions train_and_play/play_O_against_model.py
@@ -23,14 +23,23 @@
relative_folder = (script_dir / '../models/all_models').resolve()
model_path = f"{relative_folder}/q_network_3x3x3_O.pth"

params = {
"player": "X", # Player symbol for the agent
"rows": 3, # Board size (rows x rows)
"win_length": 3, # Number of in-a-row needed to win
"rewards": {
"W": 1.0, # Reward for a win
"L": -1.0, # Reward for a loss
"D": 0.5, # Reward for a draw
},
}

# Set up the game
rows = 3
win_length = 3
agent1 = MouseAgent(player="O")
agent2 = DeepQPlayingAgent(q_network=model_path, player="X")
display = ScreenDisplay(rows=rows, cols=rows, waiting_time=0.5)
agent2 = DeepQPlayingAgent(q_network=model_path, params=params)
display = ScreenDisplay(rows=params["rows"], cols=params["rows"], waiting_time=0.5)

game = TicTacToe(agent1, agent2, display=display, rows=rows, cols=rows, win_length=win_length)
game = TicTacToe(agent1, agent2, display=display, params=params)

# Play the game
game.play()
19 changes: 14 additions & 5 deletions train_and_play/play_X_against_model.py
@@ -23,14 +23,23 @@
relative_folder = (script_dir / '../models/all_models').resolve()
model_path = f"{relative_folder}/q_network_3x3x3_X.pth" # Change this path to the desired model

params = {
"player": "O", # Player symbol for the agent
"rows": 3, # Board size (rows x rows)
"win_length": 3, # Number of in-a-row needed to win
"rewards": {
"W": 1.0, # Reward for a win
"L": -1.0, # Reward for a loss
"D": 0.5, # Reward for a draw
},
}

# Set up the game
rows = 3
win_length = 3
agent1 = DeepQPlayingAgent(q_network=model_path, player="O")
agent1 = DeepQPlayingAgent(q_network=model_path, params=params)
agent2 = MouseAgent(player="X")
display = ScreenDisplay(rows=rows, cols=rows, waiting_time=0.5)
display = ScreenDisplay(rows=params["rows"], cols=params["rows"], waiting_time=0.5)

game = TicTacToe(agent1, agent2, display=display, rows=rows, cols=rows, win_length=win_length, periodic=True)
game = TicTacToe(agent1, agent2, display=display, params=params)

# Play the game
game.play()
4 changes: 2 additions & 2 deletions train_and_play/train_dqn_sweep.py
@@ -74,7 +74,7 @@
"shared_replay_buffer": False, # Share replay buffer between agents

# Q Network settings
"network_type": "CNN", # Network architecture: 'Equivariant', 'FullyCNN', 'FCN', 'CNN'
"network_type": "FullyCNN", # Network architecture: 'Equivariant', 'FullyCNN', 'FCN', 'CNN'
"periodic": False, # Periodic boundary conditions
"load_network": False, # Whether to load pretrained weights
"project_name": "TicTacToe", # Weights & Biases project name
@@ -88,7 +88,7 @@

# --- Sweep Setup ---
# param_sweep = {"replay_buffer_type": ["prioritized", "uniform"], "periodic": [True, False], "state_shape": ["one-hot", "flat"]}
param_sweep = {"periodic": [True, False], "state_shape": ["one-hot", "2D", "flat"], "network_type": ["CNN", "FullyCNN"]}
param_sweep = {"replay_buffer_type": ["prioritized", "uniform"], "symmetrized_loss": [True, False], "state_shape": ["one-hot", "flat"]}
sweep_combinations, param_keys = get_param_sweep_combinations(param_sweep)

# --- Shared Replay Buffer Setup ---
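The new sweep grid exercises exactly the combinations affected by the loss-routing change (`replay_buffer_type` × `symmetrized_loss`), plus `state_shape`. `get_param_sweep_combinations` itself is not part of this diff; assuming it simply enumerates the Cartesian product of the value lists, the eight resulting runs look like the sketch below (`base_params` is a stand-in name for the script's full config dict):

from itertools import product

param_sweep = {
    "replay_buffer_type": ["prioritized", "uniform"],
    "symmetrized_loss": [True, False],
    "state_shape": ["one-hot", "flat"],
}

# Hypothetical stand-in for get_param_sweep_combinations: fixed key order
# plus the Cartesian product of the value lists (2 * 2 * 2 = 8 runs here).
param_keys = list(param_sweep.keys())
sweep_combinations = list(product(*(param_sweep[k] for k in param_keys)))

base_params = {"rows": 3, "win_length": 3}   # abbreviated base config

for combo in sweep_combinations:
    run_params = {**base_params, **dict(zip(param_keys, combo))}
    # ...train one agent pair with run_params...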