Refactor training batch sampling and add coverage #101
Conversation
Pull Request Overview
This PR refactors the training batch sampling mechanism in DataManager to improve data source mixing and adds comprehensive test coverage. The refactoring changes from a single-loop batch generation approach to dedicated iterators for external and self-play data sources with proper fallback handling.
- Replaces batch-level sampling with individual sample iterators for better mixing control
- Implements fallback logic to fill batches when one iterator is exhausted
- Adds unit test to verify the 30% external data mixing ratio target
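The 30% external mixing target described above implies a deterministic split of each batch between the two sources. The sketch below is a hypothetical illustration of that split (the function name and signature are assumptions, not the actual `DataManager` API):

```python
def split_batch(batch_size: int, external_ratio: float = 0.3) -> tuple[int, int]:
    """Split a batch between external and self-play samples by ratio.

    Hypothetical sketch of the ratio split described in the PR overview;
    the real DataManager may compute this differently.
    """
    external = int(round(batch_size * external_ratio))
    # Guarantee at least one sample from a source whose ratio is non-zero.
    if external_ratio > 0 and external == 0:
        external = 1
    selfplay = max(0, batch_size - external)
    return external, selfplay
```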
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `azchess/data_manager.py` | Refactored `get_training_batch` method to use separate iterators and added `_iter_shard_samples` helper method |
| `tests/test_data_manager.py` | Added new test file with a comprehensive test for external/self-play data balancing |
```python
                return

        if len(combined) < batch_size:
            return
```
Returning early when the combined samples number fewer than `batch_size` silently drops those samples and can yield incomplete batches. This could lead to training instability with variable batch sizes. Consider either padding to the exact `batch_size` or enforcing a minimum batch requirement.
Suggested change (replace the early `return` with padding):

```python
# Pad with dummy samples to ensure consistent batch size
if len(combined) == 0:
    return
# Create dummy sample based on the first sample's structure
first_sample = combined[0]
state_shape = first_sample[0].shape
policy_shape = first_sample[1].shape
value_shape = () if np.isscalar(first_sample[2]) else np.shape(first_sample[2])
legal_mask_shape = (
    first_sample[3].shape
    if len(first_sample) > 3 and first_sample[3] is not None
    else None
)
num_to_pad = batch_size - len(combined)
for _ in range(num_to_pad):
    dummy_state = np.zeros(state_shape, dtype=first_sample[0].dtype)
    dummy_policy = np.zeros(policy_shape, dtype=first_sample[1].dtype)
    dummy_value = np.zeros(value_shape, dtype=np.float32)
    if legal_mask_shape is not None:
        dummy_legal_mask = np.zeros(legal_mask_shape, dtype=first_sample[3].dtype)
    else:
        dummy_legal_mask = None
    combined.append((dummy_state, dummy_policy, dummy_value, dummy_legal_mask))
```
```python
if external_ratio > 0 and external_batch_size == 0:
    external_batch_size = 1
selfplay_batch_size = max(0, batch_size - external_batch_size)

# ...

for shard_path in shard_paths:
    try:
        # Memory-map to reduce RSS and speed IO
        with np.load(shard_path, mmap_mode='r') as data:
            states, policies, values, legal_mask_all, ssl_targets = self._extract_training_arrays(data)

# ...

if selfplay_ratio > 0 and selfplay_batch_size == 0:
    selfplay_batch_size = 1
    external_batch_size = max(0, batch_size - selfplay_batch_size)
```
The logic for ensuring minimum batch sizes when ratios are non-zero is duplicated and could lead to inconsistent state. Consider consolidating this logic into a single validation step or extracting it into a helper method.
```python
batch = next(generator)
states = batch[0]
total_samples += states.shape[0]
external_samples += int(np.sum(np.isclose(states[:, 0, 0, 0], 2.0)))
```
Using the magic number `2.0` to identify external samples makes the test brittle. Consider using a named constant like `EXTERNAL_FILL_VALUE = 2.0` to make the test logic clearer and more maintainable.
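Applied to the test snippet above, the named-constant version might look like this. The constant and helper are hypothetical; the value `2.0` comes from how the test fills external states, not from `DataManager` itself.

```python
import numpy as np

# Hypothetical constant mirroring the fill value the test uses to mark
# external-source states.
EXTERNAL_FILL_VALUE = 2.0

def count_external(states: np.ndarray) -> int:
    """Count samples whose first plane element carries the external marker."""
    return int(np.sum(np.isclose(states[:, 0, 0, 0], EXTERNAL_FILL_VALUE)))
```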
💡 Codex Review
`Matrix0/azchess/data_manager.py`, lines 271 to 316 in `df70f09`
The new batching logic only classifies shards as external when `source` contains `'stockfish'`, and as self-play when the source is `'selfplay'` or empty. Shards imported via `import_replay_dir` often carry other labels such as `'external'` or `teacher:*` (see `orchestrator.import_replay_dir`), but these are now dropped from both lists. If a run relies solely on such external/teacher shards, `get_training_batch` will raise `RuntimeError("No valid training data available")` even though data exists; when mixed with self-play shards, those external shards are simply never sampled. This regression breaks training pipelines that ingest non-Stockfish external data.
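A minimal sketch of a broader classification that would accept the labels mentioned above. The function name is hypothetical, and the exact label set (`'external'`, `teacher:*`) is taken from this review comment, not verified against the repo:

```python
def classify_shard(source: str) -> str:
    """Classify a shard's origin label as 'external' or 'selfplay'.

    Hypothetical fix sketch: besides 'stockfish', also accept the
    'external' and 'teacher:*' labels that import_replay_dir can emit,
    so such shards are not dropped from both iterator lists.
    """
    src = (source or "").lower()
    if "stockfish" in src or src == "external" or src.startswith("teacher:"):
        return "external"
    return "selfplay"
```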
Summary
- `DataManager.get_training_batch` now draws from dedicated external and self-play shard iterators

Testing
- https://chatgpt.com/codex/tasks/task_e_68e6f98fd1d48323a4726f10a01a695d