diff --git a/src/ares/managers/utils/placement_strategy.py b/src/ares/managers/utils/placement_strategy.py index 3cbf66f..1b4ebf2 100644 --- a/src/ares/managers/utils/placement_strategy.py +++ b/src/ares/managers/utils/placement_strategy.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable +import numpy as np from cython_extensions import cy_distance_to_squared from loguru import logger from sc2.ids.unit_typeid import UnitTypeId as UnitID @@ -52,9 +53,21 @@ def __init__( self.building_size = building_size def _filter_by_flag( - self, flag_name: str, available, placements_for_base: dict[Point2, dict] + self, + flag_name: str, + available: list[Point2], + placements_for_base: dict[Point2, dict], ) -> list[Point2]: - return [a for a in available if placements_for_base[a].get(flag_name, False)] + grid: np.ndarray = self.placement_manager.manager_mediator.get_ground_grid + safety_check: Callable = ( + self.placement_manager.manager_mediator.is_position_safe + ) + return [ + a + for a in available + if safety_check(grid=grid, position=a) + and placements_for_base[a].get(flag_name, False) + ] class PoweredPlacementStrategy(BasePlacementStrategy):