23 changes: 17 additions & 6 deletions src/TicTacToe/DeepQAgent.py
@@ -35,6 +35,9 @@ def state_to_board(self, state: np.ndarray) -> Board:
...

class FlatStateConverter:
"""
Converts board states to flat numpy arrays for neural network input.
"""
def __init__(self, state_to_board_translation=None):
self.state_to_board_translation = state_to_board_translation or {"X": 1, "O": -1, " ": 0}
self.board_to_state_translation = {v: k for k, v in self.state_to_board_translation.items()}
@@ -47,6 +50,9 @@ def state_to_board(self, state: np.ndarray) -> Board:
return [self.board_to_state_translation[cell] for cell in flat_state]

class GridStateConverter:
"""
Converts board states to 2D grid numpy arrays for neural network input.
"""
def __init__(self, shape: tuple[int, int], state_to_board_translation=None):
self.shape = shape
self.state_to_board_translation = state_to_board_translation or {"X": 1, "O": -1, " ": 0}
@@ -62,8 +68,7 @@ def state_to_board(self, state: np.ndarray) -> Board:

class OneHotStateConverter:
"""
-Converts board states to one-hot encoded numpy arrays of shape (1, 3, rows, rows),
-where channels represent 'X', 'O', and empty respectively.
+Converts board states to one-hot encoded numpy arrays for neural network input.
"""

def __init__(self, rows: int):
@@ -100,6 +105,12 @@ def state_to_board(self, state: np.ndarray) -> Board:
class DeepQLearningAgent(Agent, EvaluationMixin):
"""
A Deep Q-Learning agent for playing Tic Tac Toe.

Attributes:
params (dict): Configuration parameters for the agent.
q_network (nn.Module): The Q-network for action-value estimation.
target_network (nn.Module): The target Q-network for stable training.
replay_buffer (ReplayBuffer): The replay buffer for storing experiences.
"""

def __init__(self, params: dict[str, Any]) -> None:
@@ -121,7 +132,7 @@ def __init__(self, params: dict[str, Any]) -> None:
self._override_with_shared_replay_buffer(params)
self._init_symmetrized_loss(params)
EvaluationMixin.__init__(
-self, wandb_enabled=params["wandb_logging"], wandb_logging_frequency=params["wandb_logging_frequency"]
+self, wandb_logging=params["wandb_logging"], wandb_logging_frequency=params["wandb_logging_frequency"]
)

def _init_config(self, params: dict[str, Any]) -> None:
@@ -140,7 +151,7 @@ def _init_config(self, params: dict[str, Any]) -> None:
self.learning_rate = params["learning_rate"]
self.replay_buffer_length = params["replay_buffer_length"]
self.wandb_logging_frequency = params["wandb_logging_frequency"]
-self.wandb = params["wandb_logging"]
+self.wandb_logging = params["wandb_logging"]
self.episode_count = 0
self.games_moves_count = 0
self.train_step_count = 0
@@ -158,7 +169,7 @@ def _init_wandb(self, params: dict[str, Any]) -> None:
Args:
params: The configuration dictionary.
"""
-if self.wandb:
+if self.wandb_logging:
wandb.init(config=params) # type: ignore

def _init_group_matrices(self) -> None:
@@ -587,7 +598,7 @@ def get_best_action(self, board: Board, q_network: nn.Module) -> Action:

class DeepQPlayingAgent(Agent):
"""
-A Deep Q-Playing agent for playing Tic Tac Toe.
+A Deep Q-Playing agent for playing Tic Tac Toe using a pretrained Q-network.
"""

def __init__(self,
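As a usage note for the converter classes above: the snippet below is a minimal, self-contained sketch of the flat translation. The translation dicts are taken verbatim from FlatStateConverter; the surrounding code is illustrative and not part of the repo's API.

import numpy as np

# Translation tables as defined in FlatStateConverter above.
state_to_board_translation = {"X": 1, "O": -1, " ": 0}
board_to_state_translation = {v: k for k, v in state_to_board_translation.items()}

board = ["X", "O", " ", " ", "X", " ", "O", " ", " "]

# Board -> flat numpy state: the form a fully connected Q-network consumes.
state = np.array([state_to_board_translation[cell] for cell in board], dtype=np.float32).reshape(1, -1)

# State -> board: mirrors the state_to_board method shown in the diff.
recovered = [board_to_state_translation[int(cell)] for cell in state.flatten()]
assert recovered == board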
19 changes: 4 additions & 15 deletions src/TicTacToe/Evaluation.py
@@ -19,7 +19,7 @@ def average_array(array: list[float] | list[int], chunk_size: Optional[int] = No

Args:
array (list[float] | list[int]): The input array of numbers.
-chunk_size (Optional[int]): The size of each chunk. If None, it defaults to 1% of the array length.
+chunk_size (Optional[int]): The size of each chunk. Defaults to 1% of the array length.

Returns:
list[float]: A list of averaged values for each chunk.
@@ -195,10 +195,7 @@ def QAgent_plays_against_RandomAgent(
Args:
Q (FullySymmetricMatrix): The Q-matrix of the agent.
player (Player): The player ('X' or 'O') for the Q-learning agent.
-nr_of_episodes (int): Number of episodes to simulate.
-rows (int): Number of rows in the TicTacToe board.
-cols (int): Number of columns in the TicTacToe board.
-win_length (int): Number of consecutive marks needed to win.
+params (dict): Configuration parameters for the simulation.
"""
nr_of_episodes = params["nr_of_episodes"]
playing_agent1 = QPlayingAgent(Q, player=player, switching=False)
@@ -232,10 +229,7 @@ def QAgent_plays_against_QAgent(
player1 (Player): The player ('X' or 'O') for the first agent.
Q2 (FullySymmetricMatrix): The Q-matrix of the second agent.
player2 (Player | None): The player ('X' or 'O') for the second agent. Defaults to the opposite of player1.
-nr_of_episodes (int): Number of episodes to simulate.
-rows (int): Number of rows in the TicTacToe board.
-cols (int): Number of columns in the TicTacToe board.
-win_length (int): Number of consecutive marks needed to win.
+params (dict): Configuration parameters for the simulation.
"""
playing_agent1 = QPlayingAgent(Q1, player=player1, switching=False)
if not player2:
@@ -267,12 +261,7 @@ def evaluate_performance(
Args:
learning_agent1 (DeepQLearningAgent): The first learning agent.
learning_agent2 (DeepQLearningAgent): The second learning agent.
-evaluation_batch_size (int): Number of games to simulate for evaluation.
-rows (int): Number of rows in the TicTacToe board.
-win_length (int): Number of consecutive marks needed to win.
-wandb_logging (bool): Whether to log results to Weights & Biases.
-device (str): The device ('cpu' or 'cuda') for computation.
-periodic (bool): Whether the board has periodic boundaries.
+params (dict): Configuration parameters for evaluation.

Returns:
dict[str, float]: A dictionary containing evaluation metrics.
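The functions above now read their settings from a single params dict. A hypothetical example, restricted to keys that actually appear in this diff (the full schema is not shown here, and the values are placeholders):

# Hypothetical params dict; only keys visible in this diff are included.
params = {
    "nr_of_episodes": 10_000,        # used by QAgent_plays_against_RandomAgent
    "learning_rate": 1e-4,           # read in DeepQLearningAgent._init_config
    "replay_buffer_length": 50_000,
    "wandb_logging": False,
    "wandb_logging_frequency": 100,
}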
16 changes: 13 additions & 3 deletions src/TicTacToe/EvaluationMixin.py
@@ -4,8 +4,18 @@
import wandb

class EvaluationMixin:
-def __init__(self, wandb_enabled: bool, wandb_logging_frequency: int):
-self.wandb = wandb_enabled
"""
A mixin class for logging and evaluating training metrics.

Attributes:
wandb (bool): Whether Weights & Biases logging is enabled.
wandb_logging_frequency (int): Frequency of logging metrics to Weights & Biases.
train_step_count (int): Counter for training steps.
episode_count (int): Counter for episodes.
evaluation_data (dict): Dictionary for storing evaluation metrics.
"""
+def __init__(self, wandb_logging: bool, wandb_logging_frequency: int):
+self.wandb_logging = wandb_logging
self.wandb_logging_frequency = wandb_logging_frequency
self.train_step_count = 0
self.episode_count = 0
@@ -30,7 +40,7 @@ def maybe_log_metrics(self) -> None:
if self.train_step_count % self.wandb_logging_frequency != 0:
return

-if self.wandb:
+if self.wandb_logging:
wandb.log({
"loss": self.safe_mean(self.evaluation_data["loss"]),
"action_value": self.safe_mean(self.evaluation_data["action_value"]),
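A minimal standalone sketch of the gate that maybe_log_metrics applies (re-implemented here for illustration; not the repo's code):

def should_log(train_step_count: int, wandb_logging: bool, frequency: int) -> bool:
    # Log only on every `frequency`-th training step, and only when W&B logging is enabled.
    return wandb_logging and train_step_count % frequency == 0

assert should_log(200, wandb_logging=True, frequency=100)
assert not should_log(150, wandb_logging=True, frequency=100)
assert not should_log(200, wandb_logging=False, frequency=100)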
88 changes: 73 additions & 15 deletions src/TicTacToe/QNetworks.py
@@ -9,7 +9,7 @@

class QNetwork(nn.Module):
"""
-A neural network for approximating the Q-function.
+A fully connected neural network for approximating the Q-function.
"""

def __init__(self, input_dim: int, output_dim: int) -> None:
@@ -31,7 +31,7 @@ def __init__(self, input_dim: int, output_dim: int) -> None:

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
-Forward pass of the QNetwork.
+Perform a forward pass through the QNetwork.

Args:
x: Input tensor.
@@ -77,7 +77,7 @@ def __init__(self, input_dim: int, rows: int, output_dim: int) -> None:

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
-Forward pass of the CNNQNetwork.
+Perform a forward pass through the CNNQNetwork.

Args:
x: Input tensor of shape (batch_size, input_dim, grid_size, grid_size).
@@ -93,13 +93,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
else:
raise ValueError(f"Unexpected input shape: {x.shape}")

x = x.view(-1, 1, self.rows, self.rows)
x = self.conv_layers(x)
x = self.fc_layers(x)
return x.view(x.size(0), -1) # shape: (batch_size, rows * cols)

class PeriodicConvBase(nn.Module):
"""
A base convolutional module with periodic padding for symmetry handling.
"""

def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'):
"""
Initialize the PeriodicConvBase.

Args:
input_dim: Number of input channels.
padding_mode: Padding mode for convolutional layers.
"""
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode),
@@ -109,21 +119,60 @@ def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass through the PeriodicConvBase.

Args:
x: Input tensor.

Returns:
Output tensor after applying the convolutional layers.
"""
return self.encoder(x)


class PeriodicQHead(nn.Module):
"""
A convolutional head for generating Q-values with periodic padding.
"""

def __init__(self, padding_mode: str = 'circular'):
"""
Initialize the PeriodicQHead.

Args:
padding_mode: Padding mode for the convolutional layer.
"""
super().__init__()
self.head = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass through the PeriodicQHead.

Args:
x: Input tensor.

Returns:
Output tensor with Q-values.
"""
x = self.head(x) # shape: (batch_size, 1, rows, cols)
return x.squeeze(1) # shape: (batch_size, rows, cols)


class FullyConvQNetwork(nn.Module):
"""
A fully convolutional Q-network with periodic padding for symmetry handling.
"""

def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'):
"""
Initialize the FullyConvQNetwork.

Args:
input_dim: Number of input channels.
padding_mode: Padding mode for convolutional layers.
"""
super().__init__()
self.base = PeriodicConvBase(input_dim=input_dim, padding_mode=padding_mode)
self.head = PeriodicQHead(padding_mode=padding_mode)
@@ -138,6 +187,15 @@ def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass through the FullyConvQNetwork.

Args:
x: Input tensor.

Returns:
Output tensor with Q-values.
"""
if x.ndim == 2: # (batch_size, rows*cols)
spatial_dim = int(x.size(1) ** 0.5)
x = x.view(x.size(0), 1, spatial_dim, spatial_dim)
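For readers unfamiliar with the 'circular' padding mode used throughout these periodic modules, here is a small standalone example (board size and weights are illustrative):

import torch
import torch.nn as nn

# With padding_mode="circular", opposite board edges are treated as adjacent,
# matching a board with periodic boundary conditions.
conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, padding_mode="circular", bias=False)

x = torch.zeros(1, 1, 3, 3)
x[0, 0, 0, 0] = 1.0  # a single mark in the top-left corner

y = conv(x)
print(y.shape)  # torch.Size([1, 1, 3, 3]) -- spatial size is preserved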
@@ -289,8 +347,8 @@ def __init__(self, weight_pattern: torch.Tensor, bias_pattern: torch.Tensor):
Initialize the EquivariantLayer.

Args:
-weight_pattern (torch.Tensor): Tensor defining the weight tying pattern.
-bias_pattern (torch.Tensor): Tensor defining the bias tying pattern.
+weight_pattern: Tensor defining the weight tying pattern.
+bias_pattern: Tensor defining the bias tying pattern.
"""
super(EquivariantLayer, self).__init__() # type: ignore

@@ -321,10 +379,10 @@ def _get_nr_of_unique_nonzero_elements(self, pattern: torch.Tensor) -> int:
Get the number of unique non-zero elements in a pattern.

Args:
-pattern (torch.Tensor): The input pattern.
+pattern: The input pattern.

Returns:
-int: The number of unique non-zero elements.
+The number of unique non-zero elements.
"""
unique_elements = list(set(pattern.detach().numpy().flatten())) # type: ignore
if 0 in unique_elements:
@@ -337,10 +395,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Perform a forward pass through the layer.

Args:
-x (torch.Tensor): Input tensor.
+x: Input tensor.

Returns:
-torch.Tensor: Output tensor after applying the layer.
+Output tensor after applying the layer.
"""
weight = torch.zeros_like(self.weight_mask, dtype=torch.float32)
bias = torch.zeros_like(self.bias_mask, dtype=torch.float32)
@@ -353,16 +411,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class EquivariantNN(nn.Module):
"""
-A neural network with multiple equivariant layers.
+A neural network with multiple equivariant layers for symmetry-aware learning.
"""

def __init__(self, groupMatrices: list[Any], ms: Tuple[int, int, int, int] = (1, 5, 5, 1)) -> None:
"""
Initialize the EquivariantNN.

Args:
-groupMatrices (list[Any]): List of transformation matrices.
-ms (Tuple[int, int, int, int]): Dimensions for each layer.
+groupMatrices: List of transformation matrices.
+ms: Dimensions for each layer.
"""
super(EquivariantNN, self).__init__() # type: ignore

@@ -391,9 +449,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Perform a forward pass through the network.

Args:
-x (torch.Tensor): Input tensor.
+x: Input tensor.

Returns:
-torch.Tensor: Output tensor after applying the network.
+Output tensor after applying the network.
"""
return self.fc_equivariant(x)
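The EquivariantLayer above builds its effective weight matrix from a tying pattern. A rough standalone sketch of that idea, with a hypothetical pattern (this is not the repo's construction):

import torch

# Entries of the pattern that share a non-zero id share one trainable parameter.
pattern = torch.tensor([[1, 2, 1],
                        [2, 3, 2],
                        [1, 2, 1]])
free_params = torch.nn.Parameter(torch.randn(3))  # one scalar per unique id

weight = torch.zeros(pattern.shape)
for i, pid in enumerate([1, 2, 3]):
    weight = weight + (pattern == pid).float() * free_params[i]

# `weight` has 9 entries but only 3 degrees of freedom, which is what makes
# the layer respect the symmetry encoded by the pattern.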
8 changes: 5 additions & 3 deletions src/TicTacToe/ReplayBuffers.py
@@ -5,7 +5,9 @@

class BaseReplayBuffer:
"""
-Base class interface for all replay buffers.
+Abstract base class for replay buffers.
+
+Defines the interface for adding experiences, sampling batches, and retrieving buffer size.
"""
def add(self, state: State, action: Action, reward: Reward, next_state: State, done: bool) -> None:
raise NotImplementedError
@@ -19,7 +21,7 @@ def __len__(self) -> int:

class ReplayBuffer(BaseReplayBuffer):
"""
-Standard uniform sampling replay buffer.
+A standard replay buffer with uniform sampling.
"""
def __init__(self, size: int, state_shape: Tuple[int, ...], device: str = "cpu") -> None:
self.size = size
@@ -74,7 +76,7 @@ def __len__(self) -> int:

class PrioritizedReplayBuffer(ReplayBuffer):
"""
-Prioritized Experience Replay buffer.
+A replay buffer that implements prioritized experience replay.
"""
def __init__(self, size, state_shape, device="cpu", alpha=0.6, beta=0.4):
super().__init__(size, state_shape, device)
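For reference, a standalone sketch of how prioritized sampling typically uses the alpha and beta defaults shown above (the buffer's internals are not part of this diff):

import numpy as np

priorities = np.array([2.0, 1.0, 0.5, 0.1])  # e.g. |TD error| per stored transition
alpha, beta = 0.6, 0.4                       # the defaults in PrioritizedReplayBuffer above

probs = priorities ** alpha
probs /= probs.sum()                         # P(i) = p_i^alpha / sum_k p_k^alpha

rng = np.random.default_rng(0)
idx = rng.choice(len(priorities), size=2, p=probs)

weights = (len(priorities) * probs[idx]) ** (-beta)
weights /= weights.max()                     # normalized importance-sampling weights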