-
Notifications
You must be signed in to change notification settings - Fork 0
Propagate norm and activation config to chess features #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,39 +196,51 @@ def _create_chess_attention_mask(self, height: int, width: int, device: torch.de | |
|
|
||
| class ChessSpecificFeatures(nn.Module): | ||
| """Chess-specific feature extraction and enhancement.""" | ||
|
|
||
| def __init__(self, channels: int, piece_square_tables: bool = True): | ||
|
|
||
| def __init__( | ||
| self, | ||
| channels: int, | ||
| piece_square_tables: bool = True, | ||
| norm: str = "batch", | ||
| activation: str = "relu", | ||
| ): | ||
| super().__init__() | ||
| self.piece_square_tables = piece_square_tables | ||
|
|
||
|
|
||
| activation_cls = nn.SiLU if activation == "silu" else nn.ReLU | ||
|
|
||
| if piece_square_tables: | ||
| # Piece-square table features | ||
| self.pst_conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False) | ||
| self.pst_norm = _norm(channels) | ||
|
|
||
| self.pst_norm = _norm(channels, norm) | ||
|
||
| self.pst_activation = activation_cls(inplace=True) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Conditional Attribute Initialization Causes API Inconsistency. The |
||
|
|
||
| # Chess-specific convolutions for piece interactions | ||
| self.interaction_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) | ||
| self.interaction_norm = _norm(channels) | ||
|
|
||
| self.interaction_norm = _norm(channels, norm) | ||
|
||
| self.interaction_activation = activation_cls(inplace=True) | ||
|
|
||
| # Position encoding for chess board (8x8 for chess) | ||
| self.position_encoding = nn.Parameter(torch.randn(1, channels, 8, 8)) | ||
|
|
||
| # Initialize position encoding properly | ||
| nn.init.normal_(self.position_encoding, mean=0.0, std=0.1) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| # Add position encoding | ||
| x = x + self.position_encoding | ||
|
|
||
| if self.piece_square_tables: | ||
| # Apply PST features | ||
| pst_features = F.relu(self.pst_norm(self.pst_conv(x))) | ||
| pst_features = self.pst_activation(self.pst_norm(self.pst_conv(x))) | ||
| x = x + pst_features | ||
|
|
||
| # Apply interaction features | ||
| interaction_features = F.relu(self.interaction_norm(self.interaction_conv(x))) | ||
| interaction_features = self.interaction_activation( | ||
| self.interaction_norm(self.interaction_conv(x)) | ||
| ) | ||
| x = x + interaction_features | ||
|
|
||
| return x | ||
|
|
||
|
|
||
|
|
@@ -299,7 +311,12 @@ def __init__(self, cfg: NetConfig): | |
|
|
||
| # Add chess-specific features if enabled | ||
| if cfg.chess_features: | ||
| self.chess_features = ChessSpecificFeatures(C, cfg.piece_square_tables) | ||
| self.chess_features = ChessSpecificFeatures( | ||
| C, | ||
| cfg.piece_square_tables, | ||
| norm=cfg.norm, | ||
| activation=cfg.activation, | ||
| ) | ||
| else: | ||
| self.chess_features = None | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from azchess.model.resnet import NetConfig, PolicyValueNet, ChessSpecificFeatures | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("norm", ["batch", "group"]) | ||
| @pytest.mark.parametrize("activation", ["relu", "silu"]) | ||
| def test_chess_specific_features_respects_config(norm: str, activation: str) -> None: | ||
| cfg = NetConfig( | ||
| planes=19, | ||
| channels=64, | ||
| blocks=1, | ||
| policy_size=4672, | ||
| se=False, | ||
| attention=False, | ||
| chess_features=True, | ||
| piece_square_tables=True, | ||
| self_supervised=False, | ||
| norm=norm, | ||
| activation=activation, | ||
| ) | ||
|
|
||
| model = PolicyValueNet(cfg) | ||
|
|
||
| assert isinstance(model.chess_features, ChessSpecificFeatures) | ||
|
|
||
| if norm == "batch": | ||
| expected_norm = nn.BatchNorm2d | ||
| else: | ||
| expected_norm = nn.GroupNorm | ||
|
|
||
| assert isinstance(model.chess_features.pst_norm, expected_norm) | ||
| assert isinstance(model.chess_features.interaction_norm, expected_norm) | ||
|
|
||
| expected_activation = nn.SiLU if activation == "silu" else nn.ReLU | ||
| assert isinstance(model.chess_features.pst_activation, expected_activation) | ||
| assert isinstance(model.chess_features.interaction_activation, expected_activation) | ||
|
|
||
| x = torch.randn(2, cfg.planes, 8, 8) | ||
| policy, value = model(x, return_ssl=False) | ||
| assert policy.shape == (2, cfg.policy_size) | ||
| assert value.shape == (2,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The activation selection logic only handles 'silu' and defaults to ReLU for all other values. Consider adding explicit validation or support for other common activations to make the behavior more predictable.