diff --git a/primus/tools/preflight/network/network_probe.py b/primus/tools/preflight/network/network_probe.py index 8e159042a..4865467f3 100644 --- a/primus/tools/preflight/network/network_probe.py +++ b/primus/tools/preflight/network/network_probe.py @@ -32,18 +32,34 @@ def _env_int(name: str, default: int) -> int: def detect_distributed_intent() -> Dict[str, Any]: world_size = _env_int("WORLD_SIZE", 1) + local_world_size = _env_int("LOCAL_WORLD_SIZE", 1) + + slurm_nnodes = _env_get("SLURM_NNODES") slurm_ntasks = _env_get("SLURM_NTASKS") + ompi_size = _env_get("OMPI_COMM_WORLD_SIZE") + ompi_local_size = _env_get("OMPI_COMM_WORLD_LOCAL_SIZE") + slurm_nnodes_i = int(slurm_nnodes) if slurm_nnodes and slurm_nnodes.isdigit() else None slurm_ntasks_i = int(slurm_ntasks) if slurm_ntasks and slurm_ntasks.isdigit() else None ompi_size_i = int(ompi_size) if ompi_size and ompi_size.isdigit() else None + ompi_local_size_i = int(ompi_local_size) if ompi_local_size and ompi_local_size.isdigit() else None is_distributed = ( - bool(world_size and world_size > 1) - or bool(slurm_ntasks_i and slurm_ntasks_i > 1) - or bool(ompi_size_i and ompi_size_i > 1) + world_size > 1 + or (slurm_ntasks_i and slurm_ntasks_i > 1) + or (ompi_size_i and ompi_size_i > 1) ) - network_mode = "multi-node" if is_distributed else "single-node" + + nnodes = 1 + if slurm_nnodes_i is not None and slurm_nnodes_i > 0: + nnodes = slurm_nnodes_i + elif ompi_size_i is not None and ompi_local_size_i is not None and ompi_size_i > 1 and ompi_local_size_i > 0: + nnodes = (ompi_size_i + ompi_local_size_i - 1) // ompi_local_size_i + elif world_size > 1 and local_world_size > 0: + nnodes = (world_size + local_world_size - 1) // local_world_size + + network_mode = "multi-node" if nnodes > 1 else "single-node" return { "is_distributed": is_distributed,