Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 158 additions & 82 deletions azchess/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,112 +278,188 @@ def get_training_batch(self, batch_size: int, device: str = "cpu") -> Iterator[T

# Target 30% external data, 70% self-play for balanced learning
target_external_ratio = 0.3
target_selfplay_ratio = 0.7

# Adjust ratios based on available data
if total_external_samples == 0:
# No external data available, use all self-play
external_ratio = 0.0
selfplay_ratio = 1.0
elif total_selfplay_samples == 0:
# No self-play data, use all external
external_ratio = 1.0
selfplay_ratio = 0.0
else:
# Balance the ratios
external_ratio = min(target_external_ratio, total_external_samples / (total_external_samples + total_selfplay_samples))
selfplay_ratio = 1.0 - external_ratio

# Calculate samples per source
external_batch_size = max(1, int(batch_size * external_ratio))
selfplay_batch_size = batch_size - external_batch_size

# Get shard paths with balanced selection
external_paths = [s.path for s in external_shards]
selfplay_paths = [s.path for s in selfplay_shards]

# Combine with appropriate weighting
if external_paths and selfplay_paths:
# Both sources available - create weighted sampling
shard_paths = (external_paths * max(1, len(selfplay_paths) // len(external_paths) if external_paths else 1) +
selfplay_paths * max(1, len(external_paths) // len(selfplay_paths) if selfplay_paths else 1))
elif external_paths:
shard_paths = external_paths
if not external_shards:
external_batch_size = 0
selfplay_batch_size = batch_size
elif not selfplay_shards:
external_batch_size = batch_size
selfplay_batch_size = 0
else:
shard_paths = selfplay_paths
external_batch_size = int(round(batch_size * external_ratio))
external_batch_size = min(batch_size, external_batch_size)
selfplay_batch_size = batch_size - external_batch_size

# Randomly sample from combined list
np.random.shuffle(shard_paths)
if external_ratio > 0 and external_batch_size == 0:
external_batch_size = 1
selfplay_batch_size = max(0, batch_size - external_batch_size)

for shard_path in shard_paths:
try:
# Memory-map to reduce RSS and speed IO
with np.load(shard_path, mmap_mode='r') as data:
states, policies, values, legal_mask_all, ssl_targets = self._extract_training_arrays(data)
if selfplay_ratio > 0 and selfplay_batch_size == 0:
selfplay_batch_size = 1
external_batch_size = max(0, batch_size - selfplay_batch_size)
Comment on lines +304 to +310
Copy link

Copilot AI Oct 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for ensuring minimum batch sizes when ratios are non-zero is duplicated and could lead to inconsistent state. Consider consolidating this logic into a single validation step or extracting it into a helper method.

Copilot uses AI. Check for mistakes.

# Normalize common shape variants proactively
# values can be (N,) or (N,1); normalize to (N,)
if values.ndim == 2 and values.shape[1] == 1:
values = values.reshape(values.shape[0])
external_iter = self._iter_shard_samples(external_shards) if external_shards else None
selfplay_iter = self._iter_shard_samples(selfplay_shards) if selfplay_shards else None

if not self._validate_shapes(states, policies, values, self.expected_planes, shard_path):
self._mark_shard_corrupted(shard_path)
continue
if external_iter is None and selfplay_iter is None:
raise RuntimeError("No valid training data available")

# Normalize legal mask to shape (N, 4672) and dtype uint8/bool
if legal_mask_all is not None:
try:
if legal_mask_all.ndim > 2:
legal_mask_all = legal_mask_all.reshape(legal_mask_all.shape[0], -1)
# Cast to uint8 to minimize memory; convert to bool later if needed
if legal_mask_all.dtype != np.uint8:
legal_mask_all = legal_mask_all.astype(np.uint8, copy=False)
except Exception:
legal_mask_all = None
def _collect_samples(sample_iter: Optional[Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]], count: int):
collected: List[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]] = []
if sample_iter is None or count <= 0:
return collected, sample_iter
while len(collected) < count:
try:
collected.append(next(sample_iter))
except StopIteration:
sample_iter = None
break
return collected, sample_iter

# Shuffle within shard
indices = np.random.permutation(len(states))
states = states[indices]
policies = policies[indices]
values = values[indices]
if legal_mask_all is not None:
legal_mask_all = legal_mask_all[indices]
while True:
ext_target = external_batch_size if external_iter else 0
sp_target = selfplay_batch_size if selfplay_iter else 0

if ext_target + sp_target == 0:
if external_iter:
ext_target = batch_size
elif selfplay_iter:
sp_target = batch_size
else:
return

ext_samples, external_iter = _collect_samples(external_iter, ext_target)
sp_samples, selfplay_iter = _collect_samples(selfplay_iter, sp_target)

combined = ext_samples + sp_samples
remaining = batch_size - len(combined)

if remaining > 0 and external_iter:
extra, external_iter = _collect_samples(external_iter, remaining)
combined.extend(extra)
remaining = batch_size - len(combined)

if remaining > 0 and selfplay_iter:
extra, selfplay_iter = _collect_samples(selfplay_iter, remaining)
combined.extend(extra)
remaining = batch_size - len(combined)

if not combined:
return

if len(combined) < batch_size:
return
Copy link

Copilot AI Oct 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early return when combined samples are less than batch_size may cause incomplete batches to be yielded. This could lead to training instability with variable batch sizes. Consider either padding to exact batch_size or ensuring minimum batch requirements are met.

Suggested change
return
# Pad with dummy samples to ensure consistent batch size
if len(combined) == 0:
return
# Create dummy sample based on the first sample's structure
first_sample = combined[0]
state_shape = first_sample[0].shape
policy_shape = first_sample[1].shape
value_shape = () if np.isscalar(first_sample[2]) else np.shape(first_sample[2])
legal_mask_shape = first_sample[3].shape if len(first_sample) > 3 and first_sample[3] is not None else None
num_to_pad = batch_size - len(combined)
for _ in range(num_to_pad):
dummy_state = np.zeros(state_shape, dtype=first_sample[0].dtype)
dummy_policy = np.zeros(policy_shape, dtype=first_sample[1].dtype)
dummy_value = np.zeros(value_shape, dtype=np.float32)
if legal_mask_shape is not None:
dummy_legal_mask = np.zeros(legal_mask_shape, dtype=first_sample[3].dtype)
else:
dummy_legal_mask = None
combined.append((dummy_state, dummy_policy, dummy_value, dummy_legal_mask))

Copilot uses AI. Check for mistakes.

states = np.stack([sample[0] for sample in combined], axis=0)
policies = np.stack([sample[1] for sample in combined], axis=0)
values = np.array([sample[2] for sample in combined], dtype=np.float32)

if values.ndim == 2 and values.shape[1] == 1:
values = values.reshape(values.shape[0])

legal_entries = [sample[3] for sample in combined]
if all(mask is not None for mask in legal_entries):
legal_mask = np.stack(legal_entries, axis=0)
if legal_mask.ndim > 2:
legal_mask = legal_mask.reshape(legal_mask.shape[0], -1)
if legal_mask.dtype != np.uint8:
legal_mask = legal_mask.astype(np.uint8, copy=False)
if not legal_mask.flags['C_CONTIGUOUS']:
legal_mask = np.ascontiguousarray(legal_mask)
else:
legal_mask = None

# Shuffle SSL targets with same indices
for key in ssl_targets:
ssl_targets[key] = ssl_targets[key][indices]
states = np.ascontiguousarray(states, dtype=np.float32)
policies = np.ascontiguousarray(policies, dtype=np.float32)
values = np.ascontiguousarray(values, dtype=np.float32)

# Yield batches
for i in range(0, len(states), batch_size):
batch_states = states[i:i+batch_size]
batch_policies = policies[i:i+batch_size]
batch_values = values[i:i+batch_size]
batch_legal = None
if legal_mask_all is not None:
batch_legal = legal_mask_all[i:i+batch_size]
if legal_mask is not None:
yield (states, policies, values, legal_mask)
else:
yield (states, policies, values)

# Ensure contiguous memory and dtypes before yielding
if not isinstance(batch_states, np.ndarray) or not batch_states.flags['C_CONTIGUOUS']:
batch_states = np.ascontiguousarray(batch_states)
if not batch_policies.flags['C_CONTIGUOUS']:
batch_policies = np.ascontiguousarray(batch_policies)
if not batch_values.flags['C_CONTIGUOUS']:
batch_values = np.ascontiguousarray(batch_values)
def _iter_shard_samples(self, shards: List[DataShard]) -> Optional[Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]]:
    """Build an endless, shuffled per-sample iterator over *shards*.

    Shard order is reshuffled on every full pass and samples are shuffled
    within each shard. The stream only terminates when a complete pass
    yields nothing (every shard corrupted or removed).

    Args:
        shards: Shard descriptors to draw samples from.

    Returns:
        A generator of ``(state, policy, value, legal_mask_or_None)``
        tuples, or ``None`` when *shards* is empty.
    """
    if not shards:
        return None

    shard_paths = [s.path for s in shards]

    def generator() -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        local_paths = list(shard_paths)
        while True:
            np.random.shuffle(local_paths)
            made_progress = False
            for shard_path in local_paths:
                prepared = self._load_and_prepare_shard(shard_path)
                if prepared is None:
                    continue
                states, policies, values, legal_mask = prepared
                for idx in range(len(states)):
                    made_progress = True
                    legal_entry = legal_mask[idx] if legal_mask is not None else None
                    # Yield outside the load/validate try-block so an
                    # exception thrown into the generator by the consumer
                    # cannot falsely mark a healthy shard as corrupted.
                    yield (states[idx], policies[idx], values[idx], legal_entry)
            # A full pass produced no samples: all shards are gone/corrupt.
            if not made_progress:
                return

    return generator()

def _load_and_prepare_shard(self, shard_path) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]]:
    """Load one shard, validate it, and return shuffled float32 arrays.

    Returns ``None`` (after marking the shard corrupted) when the shard
    cannot be read or fails shape validation. The legal-move mask, when
    present, is normalised to a contiguous 2-D uint8 array; SSL targets
    are re-ordered in place to stay aligned with the shuffled samples.
    """
    try:
        # Memory-map to reduce RSS and speed IO.
        with np.load(shard_path, mmap_mode='r') as data:
            states, policies, values, legal_mask_all, ssl_targets = self._extract_training_arrays(data)

            # values can be (N,) or (N, 1); normalise to (N,).
            if values.ndim == 2 and values.shape[1] == 1:
                values = values.reshape(values.shape[0])

            if not self._validate_shapes(states, policies, values, self.expected_planes, shard_path):
                self._mark_shard_corrupted(shard_path)
                return None

            legal_mask: Optional[np.ndarray] = None
            if legal_mask_all is not None:
                try:
                    if legal_mask_all.ndim > 2:
                        legal_mask_all = legal_mask_all.reshape(legal_mask_all.shape[0], -1)
                    # uint8 minimises memory; callers convert to bool if needed.
                    if legal_mask_all.dtype != np.uint8:
                        legal_mask_all = legal_mask_all.astype(np.uint8, copy=False)
                    legal_mask = legal_mask_all
                except Exception:
                    # Best-effort: a malformed mask degrades to "no mask".
                    legal_mask = None

            # Shuffle within the shard. Fancy indexing copies the
            # mmap-backed data, so the arrays stay valid after close.
            indices = np.random.permutation(len(states))
            states = np.ascontiguousarray(states[indices], dtype=np.float32)
            policies = np.ascontiguousarray(policies[indices], dtype=np.float32)
            values = np.ascontiguousarray(values[indices], dtype=np.float32)
            if legal_mask is not None:
                legal_mask = np.ascontiguousarray(legal_mask[indices])

            # Keep SSL targets aligned with the shuffled sample order.
            for key in ssl_targets:
                ssl_targets[key] = ssl_targets[key][indices]

            return states, policies, values, legal_mask
    except Exception as e:
        logger.error(f"Error loading shard {shard_path}: {e}", exc_info=True)
        self._mark_shard_corrupted(shard_path)
        return None

def get_external_training_batch(self, batch_size: int, source: str = "mixed") -> Optional[Dict[str, np.ndarray]]:
"""Get training batches from external training data sources.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_data_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from azchess.data_manager import DataManager


def _make_shard_data(num_samples: int, fill_value: float) -> dict[str, np.ndarray]:
states = np.full((num_samples, 19, 8, 8), fill_value, dtype=np.float32)
policies = np.full((num_samples, 4672), fill_value, dtype=np.float32)
values = np.full((num_samples,), fill_value, dtype=np.float32)
return {'s': states, 'pi': policies, 'z': values}


# Named marker values instead of magic numbers: samples are identified by
# the constant their shard was filled with, so the constants must differ.
SELFPLAY_FILL_VALUE = 1.0
EXTERNAL_FILL_VALUE = 2.0


def test_training_batch_balances_external_and_selfplay(tmp_path):
    """Batches from get_training_batch should hold ~30% external samples."""
    np.random.seed(0)

    manager = DataManager(base_dir=str(tmp_path))

    # More self-play than external data so the 30% external target is binding.
    manager.add_training_data(_make_shard_data(16, SELFPLAY_FILL_VALUE), shard_id=0, source="selfplay")
    manager.add_training_data(_make_shard_data(8, EXTERNAL_FILL_VALUE), shard_id=1, source="stockfish:mixed")

    batch_size = 10
    generator = manager.get_training_batch(batch_size)

    total_samples = 0
    external_samples = 0
    num_batches = 12

    for _ in range(num_batches):
        batch = next(generator)
        states = batch[0]
        total_samples += states.shape[0]
        # External samples carry their shard's distinctive fill value.
        external_samples += int(np.sum(np.isclose(states[:, 0, 0, 0], EXTERNAL_FILL_VALUE)))

    observed_ratio = external_samples / total_samples

    # Allow a small tolerance around the configured 30% external target.
    assert np.isclose(observed_ratio, 0.3, atol=0.05)
Loading