Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions azchess/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,14 @@ def _get_tactical_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]

try:
with np.load(tactical_path, allow_pickle=True) as data:
indices = np.random.choice(len(data['positions']), batch_size, replace=False)
total_positions = len(data['positions'])
if total_positions == 0:
logger.warning("Tactical training data is empty")
return None
draw_size = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else draw_size
Comment on lines +457 to +459
Copy link

Copilot AI Oct 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

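[nitpick] The name draw_size is misleading: it holds the batch size clamped to the number of available positions, which is the number of samples drawn only when replacement is disabled; when replacement is enabled, batch_size samples are drawn instead. Consider renaming it to sample_size_no_replace or clamped_size for clarity.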

Suggested change
draw_size = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else draw_size
sample_size_no_replace = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else sample_size_no_replace

Copilot uses AI. Check for mistakes.
indices = np.random.choice(total_positions, target_size, replace=replace)
batch_states = data['positions'][indices] # curriculum format: (N, 19, 8, 8)
batch_policies = data['policy_targets'][indices] # curriculum format: (N, 4672)
batch_values = data['value_targets'][indices] # curriculum format: (N,)
Expand Down Expand Up @@ -496,7 +503,14 @@ def _get_openings_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]

try:
with np.load(openings_path, allow_pickle=True) as data:
indices = np.random.choice(len(data['positions']), batch_size, replace=False)
total_positions = len(data['positions'])
if total_positions == 0:
logger.warning("Openings training data is empty")
return None
draw_size = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else draw_size
Comment on lines +510 to +512
Copy link

Copilot AI Oct 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

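[nitpick] The name draw_size is misleading: it holds the batch size clamped to the number of available positions, which is the number of samples drawn only when replacement is disabled; when replacement is enabled, batch_size samples are drawn instead. Consider renaming it to sample_size_no_replace or clamped_size for clarity.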

Suggested change
draw_size = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else draw_size
clamped_size = min(batch_size, total_positions)
replace = batch_size > total_positions
target_size = batch_size if replace else clamped_size

Copilot uses AI. Check for mistakes.
indices = np.random.choice(total_positions, target_size, replace=replace)
batch_states = data['positions'][indices] # curriculum format: (N, 19, 8, 8)
batch_policies = data['policy_targets'][indices] # curriculum format: (N, 4672)
batch_values = data['value_targets'][indices] # curriculum format: (N,)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_data_manager_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import chess

from azchess.data_manager import DataManager
from azchess.encoding import encode_board


POLICY_SIZE = 4672


def _write_mock_npz(base_dir, subdir, filename, sample_count):
    """Write a synthetic ``.npz`` data file to ``base_dir / subdir / filename``.

    The archive contains ``sample_count`` encodings of a default-constructed
    ``chess.Board()`` under ``positions``, together with all-zero float32
    arrays under ``policy_targets`` (``sample_count x POLICY_SIZE``) and
    ``value_targets`` (``sample_count``).

    Args:
        base_dir: Path-like root directory (must support the ``/`` operator).
        subdir: Name of the sub-directory to create below ``base_dir``.
        filename: Name of the ``.npz`` file to write.
        sample_count: Number of identical samples to store.
    """
    target_dir = base_dir / subdir
    target_dir.mkdir(parents=True, exist_ok=True)

    # Every sample is the same position; only the array shapes matter here.
    start_board = chess.Board()
    encoded = []
    for _ in range(sample_count):
        encoded.append(encode_board(start_board))

    arrays = {
        "positions": np.stack(encoded, axis=0),
        "policy_targets": np.zeros((sample_count, POLICY_SIZE), dtype=np.float32),
        "value_targets": np.zeros((sample_count,), dtype=np.float32),
    }
    np.savez(target_dir / filename, **arrays)


def test_tactical_batch_handles_small_file(tmp_path):
    """A tactical batch larger than the stored sample count is still full-sized.

    The mock file holds 2 samples while 4 are requested, so the batch can only
    reach the requested size if the loader tolerates the shortfall.
    """
    requested = 4
    _write_mock_npz(tmp_path, "tactical", "tactical_positions.npz", sample_count=2)

    batch = DataManager(base_dir=str(tmp_path))._get_tactical_batch(batch_size=requested)

    assert batch is not None
    expected_shapes = {
        's': (requested, 19, 8, 8),
        'pi': (requested, POLICY_SIZE),
        'z': (requested,),
        'legal_mask': (requested, POLICY_SIZE),
    }
    for key, shape in expected_shapes.items():
        assert batch[key].shape == shape


def test_openings_batch_handles_small_file(tmp_path):
    """An openings batch larger than the stored sample count is still full-sized.

    The mock file holds a single sample while 3 are requested, so the batch
    can only reach the requested size if the loader tolerates the shortfall.
    """
    requested = 3
    _write_mock_npz(tmp_path, "openings", "openings_positions.npz", sample_count=1)

    batch = DataManager(base_dir=str(tmp_path))._get_openings_batch(batch_size=requested)

    assert batch is not None
    expected_shapes = {
        's': (requested, 19, 8, 8),
        'pi': (requested, POLICY_SIZE),
        'z': (requested,),
        'legal_mask': (requested, POLICY_SIZE),
    }
    for key, shape in expected_shapes.items():
        assert batch[key].shape == shape
Loading