-
Notifications
You must be signed in to change notification settings - Fork 0
Refactor training batch sampling and add coverage #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -278,112 +278,188 @@ def get_training_batch(self, batch_size: int, device: str = "cpu") -> Iterator[T | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Target 30% external data, 70% self-play for balanced learning | ||||||||||||||||||||||||||||||||||||||||||
| target_external_ratio = 0.3 | ||||||||||||||||||||||||||||||||||||||||||
| target_selfplay_ratio = 0.7 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Adjust ratios based on available data | ||||||||||||||||||||||||||||||||||||||||||
| if total_external_samples == 0: | ||||||||||||||||||||||||||||||||||||||||||
| # No external data available, use all self-play | ||||||||||||||||||||||||||||||||||||||||||
| external_ratio = 0.0 | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_ratio = 1.0 | ||||||||||||||||||||||||||||||||||||||||||
| elif total_selfplay_samples == 0: | ||||||||||||||||||||||||||||||||||||||||||
| # No self-play data, use all external | ||||||||||||||||||||||||||||||||||||||||||
| external_ratio = 1.0 | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_ratio = 0.0 | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| # Balance the ratios | ||||||||||||||||||||||||||||||||||||||||||
| external_ratio = min(target_external_ratio, total_external_samples / (total_external_samples + total_selfplay_samples)) | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_ratio = 1.0 - external_ratio | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Calculate samples per source | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = max(1, int(batch_size * external_ratio)) | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = batch_size - external_batch_size | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Get shard paths with balanced selection | ||||||||||||||||||||||||||||||||||||||||||
| external_paths = [s.path for s in external_shards] | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_paths = [s.path for s in selfplay_shards] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Combine with appropriate weighting | ||||||||||||||||||||||||||||||||||||||||||
| if external_paths and selfplay_paths: | ||||||||||||||||||||||||||||||||||||||||||
| # Both sources available - create weighted sampling | ||||||||||||||||||||||||||||||||||||||||||
| shard_paths = (external_paths * max(1, len(selfplay_paths) // len(external_paths) if external_paths else 1) + | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_paths * max(1, len(external_paths) // len(selfplay_paths) if selfplay_paths else 1)) | ||||||||||||||||||||||||||||||||||||||||||
| elif external_paths: | ||||||||||||||||||||||||||||||||||||||||||
| shard_paths = external_paths | ||||||||||||||||||||||||||||||||||||||||||
| if not external_shards: | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = 0 | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = batch_size | ||||||||||||||||||||||||||||||||||||||||||
| elif not selfplay_shards: | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = batch_size | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = 0 | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| shard_paths = selfplay_paths | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = int(round(batch_size * external_ratio)) | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = min(batch_size, external_batch_size) | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = batch_size - external_batch_size | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Randomly sample from combined list | ||||||||||||||||||||||||||||||||||||||||||
| np.random.shuffle(shard_paths) | ||||||||||||||||||||||||||||||||||||||||||
| if external_ratio > 0 and external_batch_size == 0: | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = 1 | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = max(0, batch_size - external_batch_size) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for shard_path in shard_paths: | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| # Memory-map to reduce RSS and speed IO | ||||||||||||||||||||||||||||||||||||||||||
| with np.load(shard_path, mmap_mode='r') as data: | ||||||||||||||||||||||||||||||||||||||||||
| states, policies, values, legal_mask_all, ssl_targets = self._extract_training_arrays(data) | ||||||||||||||||||||||||||||||||||||||||||
| if selfplay_ratio > 0 and selfplay_batch_size == 0: | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_batch_size = 1 | ||||||||||||||||||||||||||||||||||||||||||
| external_batch_size = max(0, batch_size - selfplay_batch_size) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Normalize common shape variants proactively | ||||||||||||||||||||||||||||||||||||||||||
| # values can be (N,) or (N,1); normalize to (N,) | ||||||||||||||||||||||||||||||||||||||||||
| if values.ndim == 2 and values.shape[1] == 1: | ||||||||||||||||||||||||||||||||||||||||||
| values = values.reshape(values.shape[0]) | ||||||||||||||||||||||||||||||||||||||||||
| external_iter = self._iter_shard_samples(external_shards) if external_shards else None | ||||||||||||||||||||||||||||||||||||||||||
| selfplay_iter = self._iter_shard_samples(selfplay_shards) if selfplay_shards else None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if not self._validate_shapes(states, policies, values, self.expected_planes, shard_path): | ||||||||||||||||||||||||||||||||||||||||||
| self._mark_shard_corrupted(shard_path) | ||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||
| if external_iter is None and selfplay_iter is None: | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError("No valid training data available") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Normalize legal mask to shape (N, 4672) and dtype uint8/bool | ||||||||||||||||||||||||||||||||||||||||||
| if legal_mask_all is not None: | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| if legal_mask_all.ndim > 2: | ||||||||||||||||||||||||||||||||||||||||||
| legal_mask_all = legal_mask_all.reshape(legal_mask_all.shape[0], -1) | ||||||||||||||||||||||||||||||||||||||||||
| # Cast to uint8 to minimize memory; convert to bool later if needed | ||||||||||||||||||||||||||||||||||||||||||
| if legal_mask_all.dtype != np.uint8: | ||||||||||||||||||||||||||||||||||||||||||
| legal_mask_all = legal_mask_all.astype(np.uint8, copy=False) | ||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||
| legal_mask_all = None | ||||||||||||||||||||||||||||||||||||||||||
| def _collect_samples(sample_iter: Optional[Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]], count: int): | ||||||||||||||||||||||||||||||||||||||||||
| collected: List[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]] = [] | ||||||||||||||||||||||||||||||||||||||||||
| if sample_iter is None or count <= 0: | ||||||||||||||||||||||||||||||||||||||||||
| return collected, sample_iter | ||||||||||||||||||||||||||||||||||||||||||
| while len(collected) < count: | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| collected.append(next(sample_iter)) | ||||||||||||||||||||||||||||||||||||||||||
| except StopIteration: | ||||||||||||||||||||||||||||||||||||||||||
| sample_iter = None | ||||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||||
| return collected, sample_iter | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Shuffle within shard | ||||||||||||||||||||||||||||||||||||||||||
| indices = np.random.permutation(len(states)) | ||||||||||||||||||||||||||||||||||||||||||
| states = states[indices] | ||||||||||||||||||||||||||||||||||||||||||
| policies = policies[indices] | ||||||||||||||||||||||||||||||||||||||||||
| values = values[indices] | ||||||||||||||||||||||||||||||||||||||||||
| if legal_mask_all is not None: | ||||||||||||||||||||||||||||||||||||||||||
| legal_mask_all = legal_mask_all[indices] | ||||||||||||||||||||||||||||||||||||||||||
| while True: | ||||||||||||||||||||||||||||||||||||||||||
| ext_target = external_batch_size if external_iter else 0 | ||||||||||||||||||||||||||||||||||||||||||
| sp_target = selfplay_batch_size if selfplay_iter else 0 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if ext_target + sp_target == 0: | ||||||||||||||||||||||||||||||||||||||||||
| if external_iter: | ||||||||||||||||||||||||||||||||||||||||||
| ext_target = batch_size | ||||||||||||||||||||||||||||||||||||||||||
| elif selfplay_iter: | ||||||||||||||||||||||||||||||||||||||||||
| sp_target = batch_size | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| ext_samples, external_iter = _collect_samples(external_iter, ext_target) | ||||||||||||||||||||||||||||||||||||||||||
| sp_samples, selfplay_iter = _collect_samples(selfplay_iter, sp_target) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| combined = ext_samples + sp_samples | ||||||||||||||||||||||||||||||||||||||||||
| remaining = batch_size - len(combined) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if remaining > 0 and external_iter: | ||||||||||||||||||||||||||||||||||||||||||
| extra, external_iter = _collect_samples(external_iter, remaining) | ||||||||||||||||||||||||||||||||||||||||||
| combined.extend(extra) | ||||||||||||||||||||||||||||||||||||||||||
| remaining = batch_size - len(combined) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if remaining > 0 and selfplay_iter: | ||||||||||||||||||||||||||||||||||||||||||
| extra, selfplay_iter = _collect_samples(selfplay_iter, remaining) | ||||||||||||||||||||||||||||||||||||||||||
| combined.extend(extra) | ||||||||||||||||||||||||||||||||||||||||||
| remaining = batch_size - len(combined) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if not combined: | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if len(combined) < batch_size: | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| return | |
| # Pad with dummy samples to ensure consistent batch size | |
| if len(combined) == 0: | |
| return | |
| # Create dummy sample based on the first sample's structure | |
| first_sample = combined[0] | |
| state_shape = first_sample[0].shape | |
| policy_shape = first_sample[1].shape | |
| value_shape = () if np.isscalar(first_sample[2]) else np.shape(first_sample[2]) | |
| legal_mask_shape = first_sample[3].shape if len(first_sample) > 3 and first_sample[3] is not None else None | |
| num_to_pad = batch_size - len(combined) | |
| for _ in range(num_to_pad): | |
| dummy_state = np.zeros(state_shape, dtype=first_sample[0].dtype) | |
| dummy_policy = np.zeros(policy_shape, dtype=first_sample[1].dtype) | |
| dummy_value = np.zeros(value_shape, dtype=np.float32) | |
| if legal_mask_shape is not None: | |
| dummy_legal_mask = np.zeros(legal_mask_shape, dtype=first_sample[3].dtype) | |
| else: | |
| dummy_legal_mask = None | |
| combined.append((dummy_state, dummy_policy, dummy_value, dummy_legal_mask)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import numpy as np | ||
|
|
||
| from azchess.data_manager import DataManager | ||
|
|
||
|
|
||
| def _make_shard_data(num_samples: int, fill_value: float) -> dict[str, np.ndarray]: | ||
| states = np.full((num_samples, 19, 8, 8), fill_value, dtype=np.float32) | ||
| policies = np.full((num_samples, 4672), fill_value, dtype=np.float32) | ||
| values = np.full((num_samples,), fill_value, dtype=np.float32) | ||
| return {'s': states, 'pi': policies, 'z': values} | ||
|
|
||
|
|
||
| def test_training_batch_balances_external_and_selfplay(tmp_path): | ||
| np.random.seed(0) | ||
|
|
||
| manager = DataManager(base_dir=str(tmp_path)) | ||
|
|
||
| manager.add_training_data(_make_shard_data(16, 1.0), shard_id=0, source="selfplay") | ||
| manager.add_training_data(_make_shard_data(8, 2.0), shard_id=1, source="stockfish:mixed") | ||
|
|
||
| batch_size = 10 | ||
| generator = manager.get_training_batch(batch_size) | ||
|
|
||
| total_samples = 0 | ||
| external_samples = 0 | ||
| num_batches = 12 | ||
|
|
||
| for _ in range(num_batches): | ||
| batch = next(generator) | ||
| states = batch[0] | ||
| total_samples += states.shape[0] | ||
| external_samples += int(np.sum(np.isclose(states[:, 0, 0, 0], 2.0))) | ||
|
||
|
|
||
| observed_ratio = external_samples / total_samples | ||
|
|
||
| assert np.isclose(observed_ratio, 0.3, atol=0.05) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for ensuring minimum batch sizes when ratios are non-zero is duplicated and could lead to inconsistent state. Consider consolidating this logic into a single validation step or extracting it into a helper method.