diff --git a/azchess/model/resnet.py b/azchess/model/resnet.py
index f2b70c5..cdc9b21 100644
--- a/azchess/model/resnet.py
+++ b/azchess/model/resnet.py
@@ -196,39 +196,51 @@ def _create_chess_attention_mask(self, height: int, width: int, device: torch.de
 
 class ChessSpecificFeatures(nn.Module):
     """Chess-specific feature extraction and enhancement."""
-
-    def __init__(self, channels: int, piece_square_tables: bool = True):
+
+    def __init__(
+        self,
+        channels: int,
+        piece_square_tables: bool = True,
+        norm: str = "batch",
+        activation: str = "relu",
+    ):
         super().__init__()
         self.piece_square_tables = piece_square_tables
-
+
+        activation_cls = nn.SiLU if activation == "silu" else nn.ReLU
+
         if piece_square_tables:
             # Piece-square table features
             self.pst_conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
-            self.pst_norm = _norm(channels)
-
+            self.pst_norm = _norm(channels, norm)
+            self.pst_activation = activation_cls(inplace=True)
+
         # Chess-specific convolutions for piece interactions
         self.interaction_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
-        self.interaction_norm = _norm(channels)
-
+        self.interaction_norm = _norm(channels, norm)
+        self.interaction_activation = activation_cls(inplace=True)
+
         # Position encoding for chess board (8x8 for chess)
         self.position_encoding = nn.Parameter(torch.randn(1, channels, 8, 8))
-
+
         # Initialize position encoding properly
         nn.init.normal_(self.position_encoding, mean=0.0, std=0.1)
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Add position encoding
         x = x + self.position_encoding
-
+
         if self.piece_square_tables:
             # Apply PST features
-            pst_features = F.relu(self.pst_norm(self.pst_conv(x)))
+            pst_features = self.pst_activation(self.pst_norm(self.pst_conv(x)))
             x = x + pst_features
-
+
         # Apply interaction features
-        interaction_features = F.relu(self.interaction_norm(self.interaction_conv(x)))
+        interaction_features = self.interaction_activation(
+            self.interaction_norm(self.interaction_conv(x))
+        )
         x = x + interaction_features
-
+
         return x
 
 
@@ -299,7 +311,12 @@ def __init__(self, cfg: NetConfig):
 
         # Add chess-specific features if enabled
         if cfg.chess_features:
-            self.chess_features = ChessSpecificFeatures(C, cfg.piece_square_tables)
+            self.chess_features = ChessSpecificFeatures(
+                C,
+                cfg.piece_square_tables,
+                norm=cfg.norm,
+                activation=cfg.activation,
+            )
         else:
             self.chess_features = None
 
diff --git a/tests/test_chess_features.py b/tests/test_chess_features.py
new file mode 100644
index 0000000..1bcc0ac
--- /dev/null
+++ b/tests/test_chess_features.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from azchess.model.resnet import NetConfig, PolicyValueNet, ChessSpecificFeatures
+
+
+@pytest.mark.parametrize("norm", ["batch", "group"])
+@pytest.mark.parametrize("activation", ["relu", "silu"])
+def test_chess_specific_features_respects_config(norm: str, activation: str) -> None:
+    cfg = NetConfig(
+        planes=19,
+        channels=64,
+        blocks=1,
+        policy_size=4672,
+        se=False,
+        attention=False,
+        chess_features=True,
+        piece_square_tables=True,
+        self_supervised=False,
+        norm=norm,
+        activation=activation,
+    )
+
+    model = PolicyValueNet(cfg)
+
+    assert isinstance(model.chess_features, ChessSpecificFeatures)
+
+    if norm == "batch":
+        expected_norm = nn.BatchNorm2d
+    else:
+        expected_norm = nn.GroupNorm
+
+    assert isinstance(model.chess_features.pst_norm, expected_norm)
+    assert isinstance(model.chess_features.interaction_norm, expected_norm)
+
+    expected_activation = nn.SiLU if activation == "silu" else nn.ReLU
+    assert isinstance(model.chess_features.pst_activation, expected_activation)
+    assert isinstance(model.chess_features.interaction_activation, expected_activation)
+
+    x = torch.randn(2, cfg.planes, 8, 8)
+    policy, value = model(x, return_ssl=False)
+    assert policy.shape == (2, cfg.policy_size)
+    assert value.shape == (2,)