diff --git a/src/TicTacToe/DeepQAgent.py b/src/TicTacToe/DeepQAgent.py
index b92913e..d89e59a 100644
--- a/src/TicTacToe/DeepQAgent.py
+++ b/src/TicTacToe/DeepQAgent.py
@@ -35,6 +35,9 @@ def state_to_board(self, state: np.ndarray) -> Board:
         ...
 
 class FlatStateConverter:
+    """
+    Converts board states to flat numpy arrays for neural network input.
+    """
     def __init__(self, state_to_board_translation=None):
         self.state_to_board_translation = state_to_board_translation or {"X": 1, "O": -1, " ": 0}
         self.board_to_state_translation = {v: k for k, v in self.state_to_board_translation.items()}
@@ -47,6 +50,9 @@ def state_to_board(self, state: np.ndarray) -> Board:
         return [self.board_to_state_translation[cell] for cell in flat_state]
 
 class GridStateConverter:
+    """
+    Converts board states to 2D grid numpy arrays for neural network input.
+    """
     def __init__(self, shape: tuple[int, int], state_to_board_translation=None):
         self.shape = shape
         self.state_to_board_translation = state_to_board_translation or {"X": 1, "O": -1, " ": 0}
@@ -62,8 +68,7 @@ def state_to_board(self, state: np.ndarray) -> Board:
 
 class OneHotStateConverter:
     """
-    Converts board states to one-hot encoded numpy arrays of shape (1, 3, rows, rows),
-    where channels represent 'X', 'O', and empty respectively.
+    Converts board states to one-hot encoded numpy arrays for neural network input.
     """
 
     def __init__(self, rows: int):
@@ -100,6 +105,12 @@ def state_to_board(self, state: np.ndarray) -> Board:
 class DeepQLearningAgent(Agent, EvaluationMixin):
     """
     A Deep Q-Learning agent for playing Tic Tac Toe.
+
+    Attributes:
+        params (dict): Configuration parameters for the agent.
+        q_network (nn.Module): The Q-network for action-value estimation.
+        target_network (nn.Module): The target Q-network for stable training.
+        replay_buffer (ReplayBuffer): The replay buffer for storing experiences.
     """
 
     def __init__(self, params: dict[str, Any]) -> None:
@@ -121,7 +132,7 @@ def __init__(self, params: dict[str, Any]) -> None:
         self._override_with_shared_replay_buffer(params)
         self._init_symmetrized_loss(params)
         EvaluationMixin.__init__(
-            self, wandb_enabled=params["wandb_logging"], wandb_logging_frequency=params["wandb_logging_frequency"]
+            self, wandb_logging=params["wandb_logging"], wandb_logging_frequency=params["wandb_logging_frequency"]
         )
 
     def _init_config(self, params: dict[str, Any]) -> None:
@@ -140,7 +151,7 @@ def _init_config(self, params: dict[str, Any]) -> None:
         self.learning_rate = params["learning_rate"]
         self.replay_buffer_length = params["replay_buffer_length"]
         self.wandb_logging_frequency = params["wandb_logging_frequency"]
-        self.wandb = params["wandb_logging"]
+        self.wandb_logging = params["wandb_logging"]
         self.episode_count = 0
         self.games_moves_count = 0
         self.train_step_count = 0
@@ -158,7 +169,7 @@ def _init_wandb(self, params: dict[str, Any]) -> None:
         Args:
             params: The configuration dictionary.
         """
-        if self.wandb:
+        if self.wandb_logging:
            wandb.init(config=params)  # type: ignore
 
     def _init_group_matrices(self) -> None:
@@ -587,7 +598,7 @@ def get_best_action(self, board: Board, q_network: nn.Module) -> Action:
 
 class DeepQPlayingAgent(Agent):
     """
-    A Deep Q-Playing agent for playing Tic Tac Toe.
+    A Deep Q-Playing agent for playing Tic Tac Toe using a pretrained Q-network.
""" def __init__(self, diff --git a/src/TicTacToe/Evaluation.py b/src/TicTacToe/Evaluation.py index 3d81e9a..6c834ca 100644 --- a/src/TicTacToe/Evaluation.py +++ b/src/TicTacToe/Evaluation.py @@ -19,7 +19,7 @@ def average_array(array: list[float] | list[int], chunk_size: Optional[int] = No Args: array (list[float] | list[int]): The input array of numbers. - chunk_size (Optional[int]): The size of each chunk. If None, it defaults to 1% of the array length. + chunk_size (Optional[int]): The size of each chunk. Defaults to 1% of the array length. Returns: list[float]: A list of averaged values for each chunk. @@ -195,10 +195,7 @@ def QAgent_plays_against_RandomAgent( Args: Q (FullySymmetricMatrix): The Q-matrix of the agent. player (Player): The player ('X' or 'O') for the Q-learning agent. - nr_of_episodes (int): Number of episodes to simulate. - rows (int): Number of rows in the TicTacToe board. - cols (int): Number of columns in the TicTacToe board. - win_length (int): Number of consecutive marks needed to win. + params (dict): Configuration parameters for the simulation. """ nr_of_episodes = params["nr_of_episodes"] playing_agent1 = QPlayingAgent(Q, player=player, switching=False) @@ -232,10 +229,7 @@ def QAgent_plays_against_QAgent( player1 (Player): The player ('X' or 'O') for the first agent. Q2 (FullySymmetricMatrix): The Q-matrix of the second agent. player2 (Player | None): The player ('X' or 'O') for the second agent. Defaults to the opposite of player1. - nr_of_episodes (int): Number of episodes to simulate. - rows (int): Number of rows in the TicTacToe board. - cols (int): Number of columns in the TicTacToe board. - win_length (int): Number of consecutive marks needed to win. + params (dict): Configuration parameters for the simulation. """ playing_agent1 = QPlayingAgent(Q1, player=player1, switching=False) if not player2: @@ -267,12 +261,7 @@ def evaluate_performance( Args: learning_agent1 (DeepQLearningAgent): The first learning agent. learning_agent2 (DeepQLearningAgent): The second learning agent. - evaluation_batch_size (int): Number of games to simulate for evaluation. - rows (int): Number of rows in the TicTacToe board. - win_length (int): Number of consecutive marks needed to win. - wandb_logging (bool): Whether to log results to Weights & Biases. - device (str): The device ('cpu' or 'cuda') for computation. - periodic (bool): Whether the board has periodic boundaries. + params (dict): Configuration parameters for evaluation. Returns: dict[str, float]: A dictionary containing evaluation metrics. diff --git a/src/TicTacToe/EvaluationMixin.py b/src/TicTacToe/EvaluationMixin.py index c9231c6..df2d2f2 100644 --- a/src/TicTacToe/EvaluationMixin.py +++ b/src/TicTacToe/EvaluationMixin.py @@ -4,8 +4,18 @@ import wandb class EvaluationMixin: - def __init__(self, wandb_enabled: bool, wandb_logging_frequency: int): - self.wandb = wandb_enabled + """ + A mixin class for logging and evaluating training metrics. + + Attributes: + wandb (bool): Whether Weights & Biases logging is enabled. + wandb_logging_frequency (int): Frequency of logging metrics to Weights & Biases. + train_step_count (int): Counter for training steps. + episode_count (int): Counter for episodes. + evaluation_data (dict): Dictionary for storing evaluation metrics. 
+ """ + def __init__(self, wandb_logging: bool, wandb_logging_frequency: int): + self.wandb_logging = wandb_logging self.wandb_logging_frequency = wandb_logging_frequency self.train_step_count = 0 self.episode_count = 0 @@ -30,7 +40,7 @@ def maybe_log_metrics(self) -> None: if self.train_step_count % self.wandb_logging_frequency != 0: return - if self.wandb: + if self.wandb_logging: wandb.log({ "loss": self.safe_mean(self.evaluation_data["loss"]), "action_value": self.safe_mean(self.evaluation_data["action_value"]), diff --git a/src/TicTacToe/QNetworks.py b/src/TicTacToe/QNetworks.py index b95fca8..af45185 100644 --- a/src/TicTacToe/QNetworks.py +++ b/src/TicTacToe/QNetworks.py @@ -9,7 +9,7 @@ class QNetwork(nn.Module): """ - A neural network for approximating the Q-function. + A fully connected neural network for approximating the Q-function. """ def __init__(self, input_dim: int, output_dim: int) -> None: @@ -31,7 +31,7 @@ def __init__(self, input_dim: int, output_dim: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward pass of the QNetwork. + Perform a forward pass through the QNetwork. Args: x: Input tensor. @@ -77,7 +77,7 @@ def __init__(self, input_dim: int, rows: int, output_dim: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward pass of the CNNQNetwork. + Perform a forward pass through the CNNQNetwork. Args: x: Input tensor of shape (batch_size, input_dim, grid_size, grid_size). @@ -93,13 +93,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: raise ValueError(f"Unexpected input shape: {x.shape}") - x = x.view(-1, 1, self.rows, self.rows) x = self.conv_layers(x) x = self.fc_layers(x) return x.view(x.size(0), -1) # shape: (batch_size, rows * cols) class PeriodicConvBase(nn.Module): + """ + A base convolutional module with periodic padding for symmetry handling. + """ + def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'): + """ + Initialize the PeriodicConvBase. + + Args: + input_dim: Number of input channels. + padding_mode: Padding mode for convolutional layers. + """ super().__init__() self.encoder = nn.Sequential( nn.Conv2d(input_dim, 32, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode), @@ -109,21 +119,60 @@ def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass through the PeriodicConvBase. + + Args: + x: Input tensor. + + Returns: + Output tensor after applying the convolutional layers. + """ return self.encoder(x) class PeriodicQHead(nn.Module): + """ + A convolutional head for generating Q-values with periodic padding. + """ + def __init__(self, padding_mode: str = 'circular'): + """ + Initialize the PeriodicQHead. + + Args: + padding_mode: Padding mode for the convolutional layer. + """ super().__init__() self.head = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass through the PeriodicQHead. + + Args: + x: Input tensor. + + Returns: + Output tensor with Q-values. + """ x = self.head(x) # shape: (batch_size, 1, rows, cols) return x.squeeze(1) # shape: (batch_size, rows, cols) class FullyConvQNetwork(nn.Module): + """ + A fully convolutional Q-network with periodic padding for symmetry handling. + """ + def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'): + """ + Initialize the FullyConvQNetwork. + + Args: + input_dim: Number of input channels. 
+            padding_mode: Padding mode for convolutional layers.
+        """
         super().__init__()
         self.base = PeriodicConvBase(input_dim=input_dim, padding_mode=padding_mode)
         self.head = PeriodicQHead(padding_mode=padding_mode)
@@ -138,6 +187,15 @@ def __init__(self, input_dim: int = 1, padding_mode: str = 'circular'):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Perform a forward pass through the FullyConvQNetwork.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor with Q-values.
+        """
         if x.ndim == 2:  # (batch_size, rows*cols)
             spatial_dim = int(x.size(1) ** 0.5)
             x = x.view(x.size(0), 1, spatial_dim, spatial_dim)
@@ -289,8 +347,8 @@ def __init__(self, weight_pattern: torch.Tensor, bias_pattern: torch.Tensor):
         Initialize the EquivariantLayer.
 
         Args:
-            weight_pattern (torch.Tensor): Tensor defining the weight tying pattern.
-            bias_pattern (torch.Tensor): Tensor defining the bias tying pattern.
+            weight_pattern: Tensor defining the weight tying pattern.
+            bias_pattern: Tensor defining the bias tying pattern.
         """
 
         super(EquivariantLayer, self).__init__()  # type: ignore
@@ -321,10 +379,10 @@ def _get_nr_of_unique_nonzero_elements(self, pattern: torch.Tensor) -> int:
         Get the number of unique non-zero elements in a pattern.
 
         Args:
-            pattern (torch.Tensor): The input pattern.
+            pattern: The input pattern.
 
         Returns:
-            int: The number of unique non-zero elements.
+            The number of unique non-zero elements.
         """
         unique_elements = list(set(pattern.detach().numpy().flatten()))  # type: ignore
         if 0 in unique_elements:
@@ -337,10 +395,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Perform a forward pass through the layer.
 
         Args:
-            x (torch.Tensor): Input tensor.
+            x: Input tensor.
 
         Returns:
-            torch.Tensor: Output tensor after applying the layer.
+            Output tensor after applying the layer.
         """
         weight = torch.zeros_like(self.weight_mask, dtype=torch.float32)
         bias = torch.zeros_like(self.bias_mask, dtype=torch.float32)
@@ -353,7 +411,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class EquivariantNN(nn.Module):
     """
-    A neural network with multiple equivariant layers.
+    A neural network with multiple equivariant layers for symmetry-aware learning.
     """
 
     def __init__(self, groupMatrices: list[Any], ms: Tuple[int, int, int, int] = (1, 5, 5, 1)) -> None:
@@ -361,8 +419,8 @@ def __init__(self, groupMatrices: list[Any], ms: Tuple[int, int, int, int] = (1,
         Initialize the EquivariantNN.
 
         Args:
-            groupMatrices (list[Any]): List of transformation matrices.
-            ms (Tuple[int, int, int, int]): Dimensions for each layer.
+            groupMatrices: List of transformation matrices.
+            ms: Dimensions for each layer.
         """
 
         super(EquivariantNN, self).__init__()  # type: ignore
@@ -391,9 +449,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Perform a forward pass through the network.
 
         Args:
-            x (torch.Tensor): Input tensor.
+            x: Input tensor.
 
         Returns:
-            torch.Tensor: Output tensor after applying the network.
+            Output tensor after applying the network.
         """
         return self.fc_equivariant(x)
diff --git a/src/TicTacToe/ReplayBuffers.py b/src/TicTacToe/ReplayBuffers.py
index 9de39fc..09a80c2 100644
--- a/src/TicTacToe/ReplayBuffers.py
+++ b/src/TicTacToe/ReplayBuffers.py
@@ -5,7 +5,9 @@
 
 class BaseReplayBuffer:
     """
-    Base class interface for all replay buffers.
+    Abstract base class for replay buffers.
+
+    Defines the interface for adding experiences, sampling batches, and retrieving buffer size.
""" def add(self, state: State, action: Action, reward: Reward, next_state: State, done: bool) -> None: raise NotImplementedError @@ -19,7 +21,7 @@ def __len__(self) -> int: class ReplayBuffer(BaseReplayBuffer): """ - Standard uniform sampling replay buffer. + A standard replay buffer with uniform sampling. """ def __init__(self, size: int, state_shape: Tuple[int, ...], device: str = "cpu") -> None: self.size = size @@ -74,7 +76,7 @@ def __len__(self) -> int: class PrioritizedReplayBuffer(ReplayBuffer): """ - Prioritized Experience Replay buffer. + A replay buffer that implements prioritized experience replay. """ def __init__(self, size, state_shape, device="cpu", alpha=0.6, beta=0.4): super().__init__(size, state_shape, device) diff --git a/src/TicTacToe/SymmetricMatrix.py b/src/TicTacToe/SymmetricMatrix.py index 6168c50..b0fd2c5 100644 --- a/src/TicTacToe/SymmetricMatrix.py +++ b/src/TicTacToe/SymmetricMatrix.py @@ -13,7 +13,10 @@ class LazyComputeDict(dict[Any, Any]): """ - A dictionary that computes and stores values lazily using a provided function. + A dictionary that computes missing values lazily using a provided function. + + Attributes: + compute_func (Callable): A function to compute values for missing keys. """ def __init__(self, compute_func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: @@ -45,6 +48,7 @@ def __getitem__(self, key: Any): class BaseMatrix(ABC): """ Abstract base class for Q-value matrices. + Defines the shared interface and enforces implementation of key methods. """ @@ -101,7 +105,7 @@ def set(self, board: Board, action: Action, value: float) -> None: class Matrix(BaseMatrix): """ - A concrete implementation of BaseMatrix for storing Q-values in a standard matrix. + A standard implementation of BaseMatrix for storing Q-values in a dictionary. """ def __init__(self, file: str | None = None, default_value: float | None = None) -> None: @@ -489,7 +493,7 @@ def set(self, board: Board, action: int, value: float) -> None: class FullySymmetricMatrix(SymmetricMatrix): """ - A matrix that leverages full board symmetries, including next states, to reduce the number of stored Q-values. + A matrix that leverages full board symmetries, including next states, for Q-value storage. """ def __init__( diff --git a/src/TicTacToe/TicTacToe.py b/src/TicTacToe/TicTacToe.py index ff9e9f4..b42cb92 100644 --- a/src/TicTacToe/TicTacToe.py +++ b/src/TicTacToe/TicTacToe.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Dict, Any +from typing import List, Optional, Tuple from TicTacToe.Agent import Agent, MouseAgent from TicTacToe.Display import Display, ScreenDisplay diff --git a/src/TicTacToe/Utils.py b/src/TicTacToe/Utils.py index 038a7e8..834bb29 100644 --- a/src/TicTacToe/Utils.py +++ b/src/TicTacToe/Utils.py @@ -23,26 +23,26 @@ def get_param_sweep_combinations(param_sweep: dict) -> tuple[list[tuple[Any, ... Generate all combinations of hyperparameter values for parameter sweeping. Args: - param_sweep (dict): Dictionary mapping parameter names to lists of values. + param_sweep (dict): A dictionary where keys are parameter names and values are lists of possible values. Returns: tuple: A tuple containing: - - List of all combinations of parameter values as tuples. - - List of parameter keys in the same order as the combinations. + - A list of all combinations of parameter values as tuples. + - A list of parameter keys in the same order as the combinations. 
""" return list(product(*param_sweep.values())), list(param_sweep.keys()) def load_pretrained_models(paramsX: dict, paramsO: dict) -> tuple[dict, dict]: """ - Load pretrained models from disk and update parameter dictionaries with model paths. + Load pretrained models for players X and O from disk and update their parameter dictionaries. Args: - paramsX (dict): Parameters for player X. - paramsO (dict): Parameters for player O. + paramsX (dict): Parameter dictionary for player X. + paramsO (dict): Parameter dictionary for player O. Returns: - tuple: Updated (paramsX, paramsO) with paths to pretrained models. + tuple: A tuple containing updated parameter dictionaries for players X and O. """ script_dir = Path(__file__).resolve().parent relative_folder = (script_dir / '../models/foundational').resolve() @@ -60,13 +60,12 @@ def load_pretrained_models(paramsX: dict, paramsO: dict) -> tuple[dict, dict]: def save_model_artifacts(agent1: Agent, agent2: Agent, params: dict): """ - Save full models and weight components for both agents, store metadata in a dedicated folder, - and append summary to a central index. + Save model artifacts for both agents, including full models, weights, and metadata. Args: - agent1 (Agent): Agent playing as 'X'. - agent2 (Agent): Agent playing as 'O'. - params (dict): Parameter configuration. + agent1 (Agent): The agent playing as 'X'. + agent2 (Agent): The agent playing as 'O'. + params (dict): Configuration parameters for saving models. """ # Prepare folders base_folder = Path(params["save_models"]).resolve() @@ -139,20 +138,21 @@ def save_agent(agent, player): with open(index_file, "w") as f: json.dump(index_data, f, indent=4) -def update_exploration_rate_smoothly(agent1: DeepQLearningAgent, agent2: DeepQLearningAgent, params: dict, eval_data: dict, exploration_rate: float, win_rate_deques: tuple[deque, deque], wandb_logging = True): +def update_exploration_rate_smoothly(agent1: DeepQLearningAgent, agent2: DeepQLearningAgent, params: dict, eval_data: dict, exploration_rate: float, win_rate_deques: tuple[deque, deque], wandb_logging=True): """ - Update the exploration rate based on smoothed averages of recent win rates. + Smoothly update the exploration rate based on recent win rates. Args: - agent1 (DeepQLearningAgent): Agent playing as 'X'. - agent2 (DeepQLearningAgent): Agent playing as 'O'. - params (dict): Parameter configuration. + agent1 (DeepQLearningAgent): The agent playing as 'X'. + agent2 (DeepQLearningAgent): The agent playing as 'O'. + params (dict): Configuration parameters. eval_data (dict): Evaluation data containing win rates. exploration_rate (float): Current exploration rate. - win_rate_deques (tuple): Two deques storing recent win rates for 'X' and 'O'. + win_rate_deques (tuple): Deques storing recent win rates for 'X' and 'O'. + wandb_logging (bool): Whether to log metrics to Weights & Biases. Returns: - float: Updated exploration rate. + float: The updated exploration rate. """ X_win_rates, O_win_rates = win_rate_deques @@ -197,14 +197,13 @@ def update_exploration_rate_smoothly(agent1: DeepQLearningAgent, agent2: DeepQLe def train_and_evaluate(game: TwoPlayerBoardGame, agent1: DeepQLearningAgent, agent2: DeepQLearningAgent, params: dict): """ - Train and evaluate two agents in a Tic Tac Toe game. + Train and evaluate two agents in a Tic Tac Toe game environment. Args: - game (TwoPlayerBoardGame): The game environment. - agent1 (DeepQLearningAgent): Agent playing as 'X'. - agent2 (DeepQLearningAgent): Agent playing as 'O'. 
-        params (dict): Parameter configuration.
-        wandb_logging (bool): Whether to log to Weights & Biases.
+        game (TwoPlayerBoardGame): The game environment instance.
+        agent1 (DeepQLearningAgent): The agent playing as 'X'.
+        agent2 (DeepQLearningAgent): The agent playing as 'O'.
+        params (dict): Configuration parameters for training and evaluation.
     """
 
     wandb_logging = params["wandb_logging"]
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 805991c..3dfbffa 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -130,6 +130,97 @@ def test_deep_q_agent_training_CNN(self) -> None:
 
         self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
 
+    def test_deep_q_agent_training_CNN_2D_state(self) -> None:
+        """Simulate training of a DeepQLearningAgent during gameplay."""
+        self.params["network_type"] = "CNN"
+        self.params["periodic"] = False
+        self.params["state_shape"] = "2D"
+        self.params["rows"] = 3
+
+        agent1 = DeepQLearningAgent(params=self.params)
+        agent2 = RandomAgent(player="O")
+        game = TicTacToe(agent1, agent2, params=self.params)
+
+        # Simulate multiple episodes to test training
+        for episode in range(10):
+            outcome = game.play()
+            self.assertIn(outcome, ["X", "O", "D"], f"Game outcome in episode {episode} should be valid.")
+
+        self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
+
+    def test_deep_q_agent_training_CNN_one_hot_state(self) -> None:
+        """Simulate training of a DeepQLearningAgent during gameplay."""
+        self.params["network_type"] = "CNN"
+        self.params["periodic"] = False
+        self.params["state_shape"] = "one-hot"
+        self.params["rows"] = 3
+
+        agent1 = DeepQLearningAgent(params=self.params)
+        agent2 = RandomAgent(player="O")
+        game = TicTacToe(agent1, agent2, params=self.params)
+
+        # Simulate multiple episodes to test training
+        for episode in range(10):
+            outcome = game.play()
+            self.assertIn(outcome, ["X", "O", "D"], f"Game outcome in episode {episode} should be valid.")
+
+        self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
+
+    def test_deep_q_agent_training_CNN_more_rows(self) -> None:
+        """Simulate training of a DeepQLearningAgent during gameplay."""
+        self.params["network_type"] = "CNN"
+        self.params["periodic"] = False
+        self.params["state_shape"] = "flat"
+        self.params["rows"] = 5
+
+        agent1 = DeepQLearningAgent(params=self.params)
+        agent2 = RandomAgent(player="O")
+        game = TicTacToe(agent1, agent2, params=self.params)
+
+        # Simulate multiple episodes to test training
+        for episode in range(10):
+            outcome = game.play()
+            self.assertIn(outcome, ["X", "O", "D"], f"Game outcome in episode {episode} should be valid.")
+
+        self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
+
+    def test_deep_q_agent_training_CNN_periodic(self) -> None:
+        """Simulate training of a DeepQLearningAgent during gameplay."""
+        self.params["network_type"] = "CNN"
+        self.params["periodic"] = True
+        self.params["state_shape"] = "flat"
+        self.params["rows"] = 3
+
+        agent1 = DeepQLearningAgent(params=self.params)
+        agent2 = RandomAgent(player="O")
+        game = TicTacToe(agent1, agent2, params=self.params)
+
+
+        # Simulate multiple episodes to test training
+        for episode in range(10):
+            outcome = game.play()
+            self.assertIn(outcome, ["X", "O", "D"], f"Game outcome in episode {episode} should be valid.")
+
+        self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
+
+    def test_deep_q_agent_training_FullyCNN(self) -> None:
+        """Simulate training of a DeepQLearningAgent during gameplay."""
+        self.params["network_type"] = "FullyCNN"
+        self.params["periodic"] = False
+        self.params["state_shape"] = "flat"
+        self.params["rows"] = 3
+
+        agent1 = DeepQLearningAgent(params=self.params)
+        agent2 = RandomAgent(player="O")
+        game = TicTacToe(agent1, agent2, params=self.params)
+
+        # Simulate multiple episodes to test training
+        for episode in range(10):
+            outcome = game.play()
+            self.assertIn(outcome, ["X", "O", "D"], f"Game outcome in episode {episode} should be valid.")
+
+        self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
+
     def test_deep_q_agent_training_FullyCNN_periodic(self) -> None:
         """Simulate training of a DeepQLearningAgent during gameplay."""
         self.params["network_type"] = "FullyCNN"
@@ -166,18 +257,17 @@ def test_deep_q_agent_training_FullyCNN_periodic_2D_state(self) -> None:
 
         self.assertGreater(len(agent1.replay_buffer), 0, "Replay buffer should contain experiences after training.")
 
-    def test_deep_q_agent_training_CNN_periodic(self) -> None:
+    def test_deep_q_agent_training_FullyCNN_periodic_one_hot_state(self) -> None:
         """Simulate training of a DeepQLearningAgent during gameplay."""
-        self.params["network_type"] = "CNN"
+        self.params["network_type"] = "FullyCNN"
         self.params["periodic"] = True
-        self.params["state_shape"] = "flat"
+        self.params["state_shape"] = "one-hot"
         self.params["rows"] = 3
 
         agent1 = DeepQLearningAgent(params=self.params)
         agent2 = RandomAgent(player="O")
         game = TicTacToe(agent1, agent2, params=self.params)
 
-
         # Simulate multiple episodes to test training
         for episode in range(10):
             outcome = game.play()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f75f130..2c0243e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -13,7 +13,6 @@
 from TicTacToe.Utils import (
     get_param_sweep_combinations,
     load_pretrained_models,
-    save_model_artifacts,
     update_exploration_rate_smoothly,
     train_and_evaluate
 )