Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/test_ssl_heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import chess
import torch

from azchess.model.resnet import NetConfig, PolicyValueNet
from azchess.ssl_algorithms import ChessSSLAlgorithms
from experiments.grpo.utils.board_encoding import board_to_tensor


def _make_batch(batch_size: int = 2) -> torch.Tensor:
    """Build a batch of encoded chess positions for the SSL-head tests.

    Plays a few fixed SAN move sequences from the starting position, encodes
    each resulting board with ``board_to_tensor``, and concatenates
    ``batch_size`` encodings along dim 0.

    When ``batch_size`` exceeds the number of fixed sequences, positions are
    repeated cyclically.  (The previous one-shot ``boards.extend(...)`` only
    duplicated the list once, so it produced a short batch for any
    ``batch_size > 2 * len(sequences)``.)

    Args:
        batch_size: Number of board encodings to stack; must be >= 1.

    Returns:
        Tensor of shape ``(batch_size, *encoding_shape)`` — the exact
        per-board shape comes from ``board_to_tensor``.
    """
    sequences = [
        [],
        ["e4", "e5", "Nf3", "Nc6"],
        ["d4", "d5", "c4", "dxc4", "Nf3", "Nf6"],
    ]

    boards = []
    for seq in sequences:
        board = chess.Board()
        for san in seq:
            board.push_san(san)
        boards.append(board)

    # Cycle through the fixed positions so any batch_size is supported.
    picked = [boards[i % len(boards)] for i in range(batch_size)]
    tensors = [board_to_tensor(b) for b in picked]
    return torch.cat(tensors, dim=0)


def test_enhanced_ssl_heads_outputs_and_loss():
    """Enhanced SSL heads emit correctly shaped outputs and a finite loss,
    and the per-task loss stats track exactly the configured task set."""
    torch.manual_seed(0)
    inputs = _make_batch(batch_size=3)

    full_cfg = NetConfig(
        channels=32,
        blocks=2,
        attention=False,
        chess_features=False,
        self_supervised=True,
        ssl_tasks=["piece", "threat", "pin", "fork", "control"],
    )
    net = PolicyValueNet(full_cfg)
    net.eval()

    # Forward pass returning the SSL head outputs plus shared trunk features.
    _, _, ssl_outputs, feats = net.forward_with_features(inputs, return_ssl=True)

    n = inputs.size(0)
    # Channel counts per head: 13 piece classes, binary masks, 3-way control.
    head_channels = {"piece": 13, "threat": 1, "pin": 1, "fork": 1, "control": 3}
    expected_shapes = {task: (n, c, 8, 8) for task, c in head_channels.items()}

    assert set(ssl_outputs) == set(expected_shapes)
    for task, shape in expected_shapes.items():
        head_out = ssl_outputs[task]
        assert head_out.shape == shape
        assert head_out.dtype == inputs.dtype

    targets = ChessSSLAlgorithms().create_enhanced_ssl_targets(inputs)

    loss = net.get_enhanced_ssl_loss(inputs, targets, feats=feats)
    assert torch.isfinite(loss).item()
    assert loss.dtype == torch.float32

    tracked = {k for k in net._ssl_loss_stats if k.startswith("task:")}
    assert tracked == {f"task:{name}" for name in expected_shapes}

    # Regression: disabling tasks should immediately reflect in loss tracking
    subset_cfg = NetConfig(
        channels=32,
        blocks=2,
        attention=False,
        chess_features=False,
        self_supervised=True,
        ssl_tasks=["piece", "threat"],
    )
    subset_net = PolicyValueNet(subset_cfg)
    subset_net.eval()

    _, _, _, subset_feats = subset_net.forward_with_features(inputs, return_ssl=True)
    subset_net.get_enhanced_ssl_loss(inputs, targets, feats=subset_feats)

    subset_tracked = {k for k in subset_net._ssl_loss_stats if k.startswith("task:")}
    assert subset_tracked == {"task:piece", "task:threat"}
Loading