diff --git a/azchess/data_manager.py b/azchess/data_manager.py index f1d43e9..ca94a28 100644 --- a/azchess/data_manager.py +++ b/azchess/data_manager.py @@ -278,112 +278,188 @@ def get_training_batch(self, batch_size: int, device: str = "cpu") -> Iterator[T # Target 30% external data, 70% self-play for balanced learning target_external_ratio = 0.3 - target_selfplay_ratio = 0.7 # Adjust ratios based on available data if total_external_samples == 0: - # No external data available, use all self-play external_ratio = 0.0 selfplay_ratio = 1.0 elif total_selfplay_samples == 0: - # No self-play data, use all external external_ratio = 1.0 selfplay_ratio = 0.0 else: - # Balance the ratios external_ratio = min(target_external_ratio, total_external_samples / (total_external_samples + total_selfplay_samples)) selfplay_ratio = 1.0 - external_ratio - # Calculate samples per source - external_batch_size = max(1, int(batch_size * external_ratio)) - selfplay_batch_size = batch_size - external_batch_size - - # Get shard paths with balanced selection - external_paths = [s.path for s in external_shards] - selfplay_paths = [s.path for s in selfplay_shards] - - # Combine with appropriate weighting - if external_paths and selfplay_paths: - # Both sources available - create weighted sampling - shard_paths = (external_paths * max(1, len(selfplay_paths) // len(external_paths) if external_paths else 1) + - selfplay_paths * max(1, len(external_paths) // len(selfplay_paths) if selfplay_paths else 1)) - elif external_paths: - shard_paths = external_paths + if not external_shards: + external_batch_size = 0 + selfplay_batch_size = batch_size + elif not selfplay_shards: + external_batch_size = batch_size + selfplay_batch_size = 0 else: - shard_paths = selfplay_paths + external_batch_size = int(round(batch_size * external_ratio)) + external_batch_size = min(batch_size, external_batch_size) + selfplay_batch_size = batch_size - external_batch_size - # Randomly sample from combined list - 
np.random.shuffle(shard_paths) + if external_ratio > 0 and external_batch_size == 0: + external_batch_size = 1 + selfplay_batch_size = max(0, batch_size - external_batch_size) - for shard_path in shard_paths: - try: - # Memory-map to reduce RSS and speed IO - with np.load(shard_path, mmap_mode='r') as data: - states, policies, values, legal_mask_all, ssl_targets = self._extract_training_arrays(data) + if selfplay_ratio > 0 and selfplay_batch_size == 0: + selfplay_batch_size = 1 + external_batch_size = max(0, batch_size - selfplay_batch_size) - # Normalize common shape variants proactively - # values can be (N,) or (N,1); normalize to (N,) - if values.ndim == 2 and values.shape[1] == 1: - values = values.reshape(values.shape[0]) + external_iter = self._iter_shard_samples(external_shards) if external_shards else None + selfplay_iter = self._iter_shard_samples(selfplay_shards) if selfplay_shards else None - if not self._validate_shapes(states, policies, values, self.expected_planes, shard_path): - self._mark_shard_corrupted(shard_path) - continue + if external_iter is None and selfplay_iter is None: + raise RuntimeError("No valid training data available") - # Normalize legal mask to shape (N, 4672) and dtype uint8/bool - if legal_mask_all is not None: - try: - if legal_mask_all.ndim > 2: - legal_mask_all = legal_mask_all.reshape(legal_mask_all.shape[0], -1) - # Cast to uint8 to minimize memory; convert to bool later if needed - if legal_mask_all.dtype != np.uint8: - legal_mask_all = legal_mask_all.astype(np.uint8, copy=False) - except Exception: - legal_mask_all = None + def _collect_samples(sample_iter: Optional[Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]], count: int): + collected: List[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]] = [] + if sample_iter is None or count <= 0: + return collected, sample_iter + while len(collected) < count: + try: + collected.append(next(sample_iter)) + except StopIteration: + 
sample_iter = None + break + return collected, sample_iter - # Shuffle within shard - indices = np.random.permutation(len(states)) - states = states[indices] - policies = policies[indices] - values = values[indices] - if legal_mask_all is not None: - legal_mask_all = legal_mask_all[indices] + while True: + ext_target = external_batch_size if external_iter else 0 + sp_target = selfplay_batch_size if selfplay_iter else 0 + + if ext_target + sp_target == 0: + if external_iter: + ext_target = batch_size + elif selfplay_iter: + sp_target = batch_size + else: + return + + ext_samples, external_iter = _collect_samples(external_iter, ext_target) + sp_samples, selfplay_iter = _collect_samples(selfplay_iter, sp_target) + + combined = ext_samples + sp_samples + remaining = batch_size - len(combined) + + if remaining > 0 and external_iter: + extra, external_iter = _collect_samples(external_iter, remaining) + combined.extend(extra) + remaining = batch_size - len(combined) + + if remaining > 0 and selfplay_iter: + extra, selfplay_iter = _collect_samples(selfplay_iter, remaining) + combined.extend(extra) + remaining = batch_size - len(combined) + + if not combined: + return + + if len(combined) < batch_size: + return + + states = np.stack([sample[0] for sample in combined], axis=0) + policies = np.stack([sample[1] for sample in combined], axis=0) + values = np.array([sample[2] for sample in combined], dtype=np.float32) + + if values.ndim == 2 and values.shape[1] == 1: + values = values.reshape(values.shape[0]) + + legal_entries = [sample[3] for sample in combined] + if all(mask is not None for mask in legal_entries): + legal_mask = np.stack(legal_entries, axis=0) + if legal_mask.ndim > 2: + legal_mask = legal_mask.reshape(legal_mask.shape[0], -1) + if legal_mask.dtype != np.uint8: + legal_mask = legal_mask.astype(np.uint8, copy=False) + if not legal_mask.flags['C_CONTIGUOUS']: + legal_mask = np.ascontiguousarray(legal_mask) + else: + legal_mask = None - # Shuffle SSL targets with 
def _iter_shard_samples(self, shards: "List[DataShard]") -> Optional[Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]]:
    """Build a repeating per-sample iterator over the given shards.

    Every pass shuffles the shard order, opens each shard with ``np.load``,
    shuffles the samples inside it and yields them one at a time as
    ``(state, policy, value, legal_mask_row)``.  States, policies and values
    are converted to C-contiguous ``float32``.  When a legal mask is present it
    is flattened to 2-D and cast to ``uint8``; otherwise ``legal_mask_row`` is
    ``None``.

    A shard that fails ``_validate_shapes`` or raises while being read is
    reported through ``_mark_shard_corrupted`` and left out of all later
    passes, so it is opened, logged and marked once rather than on every pass.
    Iteration stops when a full pass produces no sample.

    The SSL targets returned by ``_extract_training_arrays`` are not part of
    the yielded tuple, so they are not reordered here; a malformed SSL array
    therefore cannot cause an otherwise usable shard to be rejected.

    Args:
        shards: Objects exposing a ``path`` attribute that ``np.load`` can open.

    Returns:
        ``None`` when ``shards`` is empty, otherwise a generator of
        ``(state, policy, value, legal_mask_row_or_None)`` tuples.
    """
    if not shards:
        return None

    shard_paths = [shard.path for shard in shards]

    def generator() -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        active_paths = list(shard_paths)
        while active_paths:
            np.random.shuffle(active_paths)
            made_progress = False
            # Paths rejected during this pass; removed only after the pass so
            # the list is never mutated while it is being iterated.
            rejected = set()

            for shard_path in active_paths:
                try:
                    # Memory-map to reduce RSS and speed IO
                    with np.load(shard_path, mmap_mode='r') as data:
                        states, policies, values, legal_mask_all, _unused_ssl = self._extract_training_arrays(data)

                        # values can be (N,) or (N, 1); normalize to (N,)
                        if values.ndim == 2 and values.shape[1] == 1:
                            values = values.reshape(values.shape[0])

                        if not self._validate_shapes(states, policies, values, self.expected_planes, shard_path):
                            self._mark_shard_corrupted(shard_path)
                            rejected.add(shard_path)
                            continue

                        # Best effort: a mask that cannot be normalized is
                        # treated as missing instead of failing the shard.
                        legal_mask: Optional[np.ndarray] = None
                        if legal_mask_all is not None:
                            try:
                                if legal_mask_all.ndim > 2:
                                    legal_mask_all = legal_mask_all.reshape(legal_mask_all.shape[0], -1)
                                if legal_mask_all.dtype != np.uint8:
                                    legal_mask_all = legal_mask_all.astype(np.uint8, copy=False)
                                legal_mask = legal_mask_all
                            except Exception:
                                legal_mask = None

                        # One permutation shared by all arrays keeps the
                        # samples aligned after the in-shard shuffle.
                        indices = np.random.permutation(len(states))
                        states = np.ascontiguousarray(states[indices], dtype=np.float32)
                        policies = np.ascontiguousarray(policies[indices], dtype=np.float32)
                        values = np.ascontiguousarray(values[indices], dtype=np.float32)
                        if legal_mask is not None:
                            legal_mask = np.ascontiguousarray(legal_mask[indices])

                        for idx, (state, policy, value) in enumerate(zip(states, policies, values)):
                            made_progress = True
                            legal_entry = None if legal_mask is None else legal_mask[idx]
                            yield (state, policy, value, legal_entry)
                except Exception as e:
                    logger.error(f"Error loading shard {shard_path}: {e}", exc_info=True)
                    self._mark_shard_corrupted(shard_path)
                    rejected.add(shard_path)

            if not made_progress:
                return

            if rejected:
                active_paths = [path for path in active_paths if path not in rejected]

    return generator()