-
Notifications
You must be signed in to change notification settings - Fork 0
Ensure inference fallback respects dynamic policy size #95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -72,6 +72,32 @@ def _recreate_worker_events(shared_memory_resources: List[Dict[str, Any]], attem | |||||||
| logger.error(f"Failed to recreate events for worker {i}: {recreate_error}") | ||||||||
|
|
||||||||
|
|
||||||||
| def _get_policy_width_from_resource(resource: Dict[str, Any]) -> int: | ||||||||
| """Safely derive the policy width from a shared resource mapping.""" | ||||||||
| policy_tensor = None | ||||||||
| try: | ||||||||
| policy_tensor = resource.get("response_policy_tensor") # type: ignore[attr-defined] | ||||||||
| except AttributeError: | ||||||||
| # Resource may not behave like a mapping (should not happen, but guard anyway) | ||||||||
| pass | ||||||||
|
|
||||||||
| if policy_tensor is None: | ||||||||
| return 0 | ||||||||
|
|
||||||||
| shape = getattr(policy_tensor, "shape", ()) | ||||||||
| if len(shape) >= 2: | ||||||||
| width = shape[1] | ||||||||
| elif len(shape) == 1: | ||||||||
| width = shape[0] | ||||||||
| else: | ||||||||
| width = 0 | ||||||||
|
|
||||||||
| try: | ||||||||
| return int(width) | ||||||||
| except (TypeError, ValueError): | ||||||||
| return 0 | ||||||||
|
|
||||||||
|
|
||||||||
| def run_inference_server( | ||||||||
| device: str, | ||||||||
| model_cfg: dict, | ||||||||
|
|
@@ -454,9 +480,31 @@ def run_inference_server( | |||||||
| logger.error(f"Batch tensor device: {batch_tensor.device}") | ||||||||
| logger.error(f"Model device: {next(model.parameters()).device}") | ||||||||
|
|
||||||||
| # Return fallback values | ||||||||
| # Return fallback values sized to the shared policy tensor | ||||||||
| batch_size = batch_tensor.shape[0] | ||||||||
| policy_logits_np = np.zeros((batch_size, 4672), dtype=np.float32) | ||||||||
| policy_width = 0 | ||||||||
| candidate_workers: List[int] = list(batch_sizes.keys()) | ||||||||
| for worker_id in batch_indices: | ||||||||
| if worker_id not in batch_sizes: | ||||||||
| candidate_workers.append(worker_id) | ||||||||
|
|
||||||||
| for worker_id in candidate_workers: | ||||||||
| if 0 <= worker_id < len(shared_memory_resources): | ||||||||
| policy_width = _get_policy_width_from_resource( | ||||||||
| shared_memory_resources[worker_id] | ||||||||
| ) | ||||||||
| if policy_width > 0: | ||||||||
| break | ||||||||
|
|
||||||||
| if policy_width <= 0: | ||||||||
| for res in shared_memory_resources: | ||||||||
| policy_width = _get_policy_width_from_resource(res) | ||||||||
| if policy_width > 0: | ||||||||
| break | ||||||||
|
|
||||||||
| policy_logits_np = np.zeros( | ||||||||
| (batch_size, policy_width), dtype=np.float32 | ||||||||
| ) | ||||||||
|
Comment on lines +505 to +507
|
||||||||
| value = np.zeros((batch_size, 1), dtype=np.float32) | ||||||||
|
|
||||||||
| # Distribute results back to workers | ||||||||
|
|
@@ -645,7 +693,12 @@ def infer_np(self, arr_batch: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | |||||||
| def _get_fallback_values(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray]: | ||||||||
| """Return safe fallback values when inference fails completely.""" | ||||||||
| # Return neutral logits (uniform after softmax) and neutral value | ||||||||
| policy = np.zeros((batch_size, 4672), dtype=np.float32) | ||||||||
| policy_width = _get_policy_width_from_resource(self.res) | ||||||||
| if policy_width <= 0: | ||||||||
| self.logger.debug( | ||||||||
| "Falling back to zero-width policy logits due to missing shape information" | ||||||||
| ) | ||||||||
|
||||||||
| ) | |
| ) | |
| policy_width = max(1, policy_width) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
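`_get_policy_width_from_resource` returns 0 when it cannot determine a width, but the fallback path then uses that value as a tensor dimension in `np.zeros((batch_size, policy_width), dtype=np.float32)`. A zero-width policy tensor would cause runtime errors. Consider returning a sensible default width or raising an exception to fail fast.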