Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions azchess/selfplay/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,32 @@ def _recreate_worker_events(shared_memory_resources: List[Dict[str, Any]], attem
logger.error(f"Failed to recreate events for worker {i}: {recreate_error}")


def _get_policy_width_from_resource(resource: Dict[str, Any]) -> int:
"""Safely derive the policy width from a shared resource mapping."""
policy_tensor = None
try:
policy_tensor = resource.get("response_policy_tensor") # type: ignore[attr-defined]
except AttributeError:
# Resource may not behave like a mapping (should not happen, but guard anyway)
pass

if policy_tensor is None:
return 0

shape = getattr(policy_tensor, "shape", ())
if len(shape) >= 2:
width = shape[1]
elif len(shape) == 1:
width = shape[0]
else:
width = 0

try:
return int(width)
except (TypeError, ValueError):
return 0
Comment on lines +96 to +98
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function returns 0 for invalid width values, but later code uses this as a tensor dimension. A zero-width tensor would cause runtime errors. Consider returning a sensible default width or raising an exception to fail fast.

Copilot uses AI. Check for mistakes.


def run_inference_server(
device: str,
model_cfg: dict,
Expand Down Expand Up @@ -454,9 +480,31 @@ def run_inference_server(
logger.error(f"Batch tensor device: {batch_tensor.device}")
logger.error(f"Model device: {next(model.parameters()).device}")

# Return fallback values
# Return fallback values sized to the shared policy tensor
batch_size = batch_tensor.shape[0]
policy_logits_np = np.zeros((batch_size, 4672), dtype=np.float32)
policy_width = 0
candidate_workers: List[int] = list(batch_sizes.keys())
for worker_id in batch_indices:
if worker_id not in batch_sizes:
candidate_workers.append(worker_id)

for worker_id in candidate_workers:
if 0 <= worker_id < len(shared_memory_resources):
policy_width = _get_policy_width_from_resource(
shared_memory_resources[worker_id]
)
if policy_width > 0:
break

if policy_width <= 0:
for res in shared_memory_resources:
policy_width = _get_policy_width_from_resource(res)
if policy_width > 0:
break

policy_logits_np = np.zeros(
(batch_size, policy_width), dtype=np.float32
)
Comment on lines +505 to +507
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When policy_width is 0 (from the helper function), this creates a tensor with shape (batch_size, 0) which will likely cause errors in downstream code expecting valid policy dimensions.

Copilot uses AI. Check for mistakes.
value = np.zeros((batch_size, 1), dtype=np.float32)

# Distribute results back to workers
Expand Down Expand Up @@ -645,7 +693,12 @@ def infer_np(self, arr_batch: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
def _get_fallback_values(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
"""Return safe fallback values when inference fails completely."""
# Return neutral logits (uniform after softmax) and neutral value
policy = np.zeros((batch_size, 4672), dtype=np.float32)
policy_width = _get_policy_width_from_resource(self.res)
if policy_width <= 0:
self.logger.debug(
"Falling back to zero-width policy logits due to missing shape information"
)
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the server path, creating a zero-width policy tensor when policy_width is 0 will cause runtime errors. The fallback should ensure a valid tensor dimension.

Suggested change
)
)
policy_width = max(1, policy_width)

Copilot uses AI. Check for mistakes.
policy = np.zeros((batch_size, policy_width), dtype=np.float32)
value = np.zeros(batch_size, dtype=np.float32)
return policy, value

Expand Down
Loading