diff --git a/azchess/data_manager.py b/azchess/data_manager.py index f1d43e9..ed0816f 100644 --- a/azchess/data_manager.py +++ b/azchess/data_manager.py @@ -450,7 +450,14 @@ def _get_tactical_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray] try: with np.load(tactical_path, allow_pickle=True) as data: - indices = np.random.choice(len(data['positions']), batch_size, replace=False) + total_positions = len(data['positions']) + if total_positions == 0: + logger.warning("Tactical training data is empty") + return None + draw_size = min(batch_size, total_positions) + replace = batch_size > total_positions + target_size = batch_size if replace else draw_size + indices = np.random.choice(total_positions, target_size, replace=replace) batch_states = data['positions'][indices] # curriculum format: (N, 19, 8, 8) batch_policies = data['policy_targets'][indices] # curriculum format: (N, 4672) batch_values = data['value_targets'][indices] # curriculum format: (N,) @@ -496,7 +503,14 @@ def _get_openings_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray] try: with np.load(openings_path, allow_pickle=True) as data: - indices = np.random.choice(len(data['positions']), batch_size, replace=False) + total_positions = len(data['positions']) + if total_positions == 0: + logger.warning("Openings training data is empty") + return None + draw_size = min(batch_size, total_positions) + replace = batch_size > total_positions + target_size = batch_size if replace else draw_size + indices = np.random.choice(total_positions, target_size, replace=replace) batch_states = data['positions'][indices] # curriculum format: (N, 19, 8, 8) batch_policies = data['policy_targets'][indices] # curriculum format: (N, 4672) batch_values = data['value_targets'][indices] # curriculum format: (N,) diff --git a/tests/test_data_manager_batches.py b/tests/test_data_manager_batches.py new file mode 100644 index 0000000..fdc693d --- /dev/null +++ b/tests/test_data_manager_batches.py @@ -0,0 +1,49 @@ +import numpy as np +import chess + +from azchess.data_manager import DataManager +from azchess.encoding import encode_board + + +POLICY_SIZE = 4672 + + +def _write_mock_npz(base_dir, subdir, filename, sample_count): + path = base_dir / subdir + path.mkdir(parents=True, exist_ok=True) + board = chess.Board() + positions = np.stack([encode_board(board) for _ in range(sample_count)], axis=0) + policy_targets = np.zeros((sample_count, POLICY_SIZE), dtype=np.float32) + value_targets = np.zeros((sample_count,), dtype=np.float32) + np.savez( + path / filename, + positions=positions, + policy_targets=policy_targets, + value_targets=value_targets, + ) + + +def test_tactical_batch_handles_small_file(tmp_path): + _write_mock_npz(tmp_path, "tactical", "tactical_positions.npz", sample_count=2) + manager = DataManager(base_dir=str(tmp_path)) + + batch = manager._get_tactical_batch(batch_size=4) + + assert batch is not None + assert batch['s'].shape == (4, 19, 8, 8) + assert batch['pi'].shape == (4, POLICY_SIZE) + assert batch['z'].shape == (4,) + assert batch['legal_mask'].shape == (4, POLICY_SIZE) + + +def test_openings_batch_handles_small_file(tmp_path): + _write_mock_npz(tmp_path, "openings", "openings_positions.npz", sample_count=1) + manager = DataManager(base_dir=str(tmp_path)) + + batch = manager._get_openings_batch(batch_size=3) + + assert batch is not None + assert batch['s'].shape == (3, 19, 8, 8) + assert batch['pi'].shape == (3, POLICY_SIZE) + assert batch['z'].shape == (3,) + assert batch['legal_mask'].shape == (3, POLICY_SIZE)