Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions nemo_run/run/ray/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:

return " ".join(_srun_flags)

def get_command_srun_args() -> str:
    """Return the extra srun arguments for the COMMAND launch as one string.

    The arguments of the first resource group are used when the executor has
    ``run_as_group`` and ``heterogeneous`` set, has a truthy
    ``resource_group``, and that first entry's ``srun_args`` is not ``None``.
    In every other case the executor's own ``srun_args`` are used (an empty
    list when they are falsy).

    Returns:
        The selected arguments, each passed through ``shlex.quote`` and joined
        with single spaces; an empty string when there are no arguments.
    """
    executor = self.executor

    # Only consult the first resource group for grouped heterogeneous runs.
    head_group_args = None
    if executor.run_as_group and executor.heterogeneous and executor.resource_group:
        head_group_args = executor.resource_group[0].srun_args

    if head_group_args is not None:
        selected_args = head_group_args
    else:
        selected_args = executor.srun_args or []

    # Quote each argument individually so it stays a single shell word.
    return " ".join(map(shlex.quote, selected_args))

ray_log_prefix = job_details.ray_log_prefix
vars_to_fill = {
"sbatch_flags": sbatch_flags,
Expand All @@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
"ray_log_prefix": ray_log_prefix,
"heterogeneous": self.executor.heterogeneous,
"resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
"command_srun_args": get_command_srun_args(),
}

if self.command_groups:
Expand Down
2 changes: 1 addition & 1 deletion nemo_run/run/ray/templates/ray.sub.j2
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}

if [[ -n "$COMMAND" ]]; then
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap {% if command_srun_args %}{{ command_srun_args }} {% endif %}--container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
else
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
Expand Down
64 changes: 64 additions & 0 deletions test/run/ray/test_slurm_ray_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,70 @@ def test_command_groups_without_resource_group(self):
assert "--overlap" in script
assert "cmd1" in script # Second command in the list (index 1)

def test_command_srun_honors_executor_srun_args(self):
    """The srun line that launches COMMAND must carry the executor's srun_args."""
    executor = SlurmExecutor(account="test_account", srun_args=["--mpi=pmix"])
    tunnel = Mock(spec=SSHTunnel)
    executor.tunnel = tunnel
    executor.tunnel.job_dir = "/tmp/test_jobs"

    request_kwargs = {
        "name": "test-ray-cluster",
        "cluster_dir": "/tmp/test_jobs/test-ray-cluster",
        "template_name": "ray.sub.j2",
        "executor": executor,
        "command": "echo hello",
        "launch_cmd": ["sbatch", "--parsable"],
    }
    script = SlurmRayRequest(**request_kwargs).materialize()

    # The executor-level args must sit between --overlap and --container-name.
    expected_fragment = "--gpus=0 --overlap --mpi=pmix --container-name=ray-head"
    assert expected_fragment in script

def test_command_srun_honors_head_resource_group_srun_args(self):
    """For heterogeneous grouped runs, COMMAND must use the first resource group's srun_args."""
    executor = SlurmExecutor(
        account="test_account",
        heterogeneous=True,
        srun_args=["--mpi=none"],
    )
    executor.run_as_group = True

    # First group carries its own srun_args; the second one sets none.
    head_group = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        container_image="image1",
        container_mounts=["/data:/data"],
        srun_args=["--mpi=pmix"],
        het_group_index=0,
    )
    second_group = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        container_image="image2",
        container_mounts=["/data:/data"],
        het_group_index=1,
    )
    executor.resource_group = [head_group, second_group]

    executor.tunnel = Mock(spec=SSHTunnel)
    executor.tunnel.job_dir = "/tmp/test_jobs"

    request_kwargs = {
        "name": "test-ray-cluster",
        "cluster_dir": "/tmp/test_jobs/test-ray-cluster",
        "template_name": "ray.sub.j2",
        "executor": executor,
        "command": "echo hello",
        "launch_cmd": ["sbatch", "--parsable"],
    }
    script = SlurmRayRequest(**request_kwargs).materialize()

    expected_fragment = (
        "--het-group=0 --no-container-mount-home --gpus=0 --overlap --mpi=pmix "
        "--container-name=ray-head"
    )
    # The executor-level value must not appear on the COMMAND srun line.
    unexpected_fragment = "--gpus=0 --overlap --mpi=none --container-name=ray-head"

    assert expected_fragment in script
    assert unexpected_fragment not in script

def test_env_vars_formatting(self):
"""Test that environment variables are properly formatted as export statements."""
executor = SlurmExecutor(
Expand Down
Loading