Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions primus/tools/preflight/network/network_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,34 @@ def _env_int(name: str, default: int) -> int:

def detect_distributed_intent() -> Dict[str, Any]:
world_size = _env_int("WORLD_SIZE", 1)
local_world_size = _env_int("LOCAL_WORLD_SIZE", 1)

slurm_nnodes = _env_get("SLURM_NNODES")
slurm_ntasks = _env_get("SLURM_NTASKS")

ompi_size = _env_get("OMPI_COMM_WORLD_SIZE")
ompi_local_size = _env_get("OMPI_COMM_WORLD_LOCAL_SIZE")

slurm_nnodes_i = int(slurm_nnodes) if slurm_nnodes and slurm_nnodes.isdigit() else None
slurm_ntasks_i = int(slurm_ntasks) if slurm_ntasks and slurm_ntasks.isdigit() else None
ompi_size_i = int(ompi_size) if ompi_size and ompi_size.isdigit() else None
ompi_local_size_i = int(ompi_local_size) if ompi_local_size and ompi_local_size.isdigit() else None

is_distributed = (
bool(world_size and world_size > 1)
or bool(slurm_ntasks_i and slurm_ntasks_i > 1)
or bool(ompi_size_i and ompi_size_i > 1)
world_size > 1
or (slurm_ntasks_i and slurm_ntasks_i > 1)
or (ompi_size_i and ompi_size_i > 1)
)
network_mode = "multi-node" if is_distributed else "single-node"

nnodes = 1
if slurm_nnodes_i is not None and slurm_nnodes_i > 0:
nnodes = slurm_nnodes_i
elif ompi_size_i is not None and ompi_local_size_i is not None and ompi_size_i > 1 and ompi_local_size_i > 0:
nnodes = (ompi_size_i + ompi_local_size_i - 1) // ompi_local_size_i
elif world_size > 1 and local_world_size > 0:
nnodes = (world_size + local_world_size - 1) // local_world_size

network_mode = "multi-node" if nnodes > 1 else "single-node"

return {
"is_distributed": is_distributed,
Expand Down
Loading