diff --git a/azchess/selfplay/inference.py b/azchess/selfplay/inference.py index 09aeea9..aa5e2cc 100644 --- a/azchess/selfplay/inference.py +++ b/azchess/selfplay/inference.py @@ -72,6 +72,32 @@ def _recreate_worker_events(shared_memory_resources: List[Dict[str, Any]], attem logger.error(f"Failed to recreate events for worker {i}: {recreate_error}") +def _get_policy_width_from_resource(resource: Dict[str, Any]) -> int: + """Safely derive the policy width from a shared resource mapping.""" + policy_tensor = None + try: + policy_tensor = resource.get("response_policy_tensor") # type: ignore[attr-defined] + except AttributeError: + # Resource may not behave like a mapping (should not happen, but guard anyway) + pass + + if policy_tensor is None: + return 0 + + shape = getattr(policy_tensor, "shape", ()) + if len(shape) >= 2: + width = shape[1] + elif len(shape) == 1: + width = shape[0] + else: + width = 0 + + try: + return int(width) + except (TypeError, ValueError): + return 0 + + def run_inference_server( device: str, model_cfg: dict, @@ -454,9 +480,31 @@ def run_inference_server( logger.error(f"Batch tensor device: {batch_tensor.device}") logger.error(f"Model device: {next(model.parameters()).device}") - # Return fallback values + # Return fallback values sized to the shared policy tensor batch_size = batch_tensor.shape[0] - policy_logits_np = np.zeros((batch_size, 4672), dtype=np.float32) + policy_width = 0 + candidate_workers: List[int] = list(batch_sizes.keys()) + for worker_id in batch_indices: + if worker_id not in batch_sizes: + candidate_workers.append(worker_id) + + for worker_id in candidate_workers: + if 0 <= worker_id < len(shared_memory_resources): + policy_width = _get_policy_width_from_resource( + shared_memory_resources[worker_id] + ) + if policy_width > 0: + break + + if policy_width <= 0: + for res in shared_memory_resources: + policy_width = _get_policy_width_from_resource(res) + if policy_width > 0: + break + + policy_logits_np = np.zeros( + (batch_size, policy_width), dtype=np.float32 + ) value = np.zeros((batch_size, 1), dtype=np.float32) # Distribute results back to workers @@ -645,7 +693,12 @@ def infer_np(self, arr_batch: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: def _get_fallback_values(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray]: """Return safe fallback values when inference fails completely.""" # Return neutral logits (uniform after softmax) and neutral value - policy = np.zeros((batch_size, 4672), dtype=np.float32) + policy_width = _get_policy_width_from_resource(self.res) + if policy_width <= 0: + self.logger.debug( + "Falling back to zero-width policy logits due to missing shape information" + ) + policy = np.zeros((batch_size, policy_width), dtype=np.float32) value = np.zeros(batch_size, dtype=np.float32) return policy, value