diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py
index b82ecf4b..d631397c 100644
--- a/nemo_run/run/ray/slurm.py
+++ b/nemo_run/run/ray/slurm.py
@@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
 
             return " ".join(_srun_flags)
 
+        def get_command_srun_args() -> str:
+            if (
+                self.executor.run_as_group
+                and self.executor.heterogeneous
+                and self.executor.resource_group
+                and self.executor.resource_group[0].srun_args is not None
+            ):
+                command_srun_args = self.executor.resource_group[0].srun_args
+            else:
+                command_srun_args = self.executor.srun_args or []
+
+            return " ".join(shlex.quote(arg) for arg in command_srun_args)
+
         ray_log_prefix = job_details.ray_log_prefix
         vars_to_fill = {
             "sbatch_flags": sbatch_flags,
@@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
             "ray_log_prefix": ray_log_prefix,
             "heterogeneous": self.executor.heterogeneous,
             "resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
+            "command_srun_args": get_command_srun_args(),
         }
 
         if self.command_groups:
diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2
index 8d510cb2..81a25c2c 100644
--- a/nemo_run/run/ray/templates/ray.sub.j2
+++ b/nemo_run/run/ray/templates/ray.sub.j2
@@ -454,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
 COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
 
 if [[ -n "$COMMAND" ]]; then
-    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap {% if command_srun_args %}{{ command_srun_args }} {% endif %}--container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
 else
     echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
     cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py
index d9a41ae7..652d283b 100644
--- a/test/run/ray/test_slurm_ray_request.py
+++ b/test/run/ray/test_slurm_ray_request.py
@@ -627,6 +627,70 @@ def test_command_groups_without_resource_group(self):
         assert "--overlap" in script
         assert "cmd1" in script  # Second command in the list (index 1)
 
+    def test_command_srun_honors_executor_srun_args(self):
+        """Test that the COMMAND launch srun includes executor srun_args."""
+        executor = SlurmExecutor(account="test_account", srun_args=["--mpi=pmix"])
+        executor.tunnel = Mock(spec=SSHTunnel)
+        executor.tunnel.job_dir = "/tmp/test_jobs"
+
+        request = SlurmRayRequest(
+            name="test-ray-cluster",
+            cluster_dir="/tmp/test_jobs/test-ray-cluster",
+            template_name="ray.sub.j2",
+            executor=executor,
+            command="echo hello",
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        script = request.materialize()
+        assert "--gpus=0 --overlap --mpi=pmix --container-name=ray-head" in script
+
+    def test_command_srun_honors_head_resource_group_srun_args(self):
+        """Test that heterogeneous grouped runs use head resource-group srun_args for COMMAND."""
+        executor = SlurmExecutor(
+            account="test_account",
+            heterogeneous=True,
+            srun_args=["--mpi=none"],
+        )
+        executor.run_as_group = True
+        executor.resource_group = [
+            SlurmExecutor.ResourceRequest(
+                packager=Mock(),
+                nodes=1,
+                ntasks_per_node=1,
+                container_image="image1",
+                container_mounts=["/data:/data"],
+                srun_args=["--mpi=pmix"],
+                het_group_index=0,
+            ),
+            SlurmExecutor.ResourceRequest(
+                packager=Mock(),
+                nodes=1,
+                ntasks_per_node=1,
+                container_image="image2",
+                container_mounts=["/data:/data"],
+                het_group_index=1,
+            ),
+        ]
+        executor.tunnel = Mock(spec=SSHTunnel)
+        executor.tunnel.job_dir = "/tmp/test_jobs"
+
+        request = SlurmRayRequest(
+            name="test-ray-cluster",
+            cluster_dir="/tmp/test_jobs/test-ray-cluster",
+            template_name="ray.sub.j2",
+            executor=executor,
+            command="echo hello",
+            launch_cmd=["sbatch", "--parsable"],
+        )
+
+        script = request.materialize()
+        assert (
+            "--het-group=0 --no-container-mount-home --gpus=0 --overlap --mpi=pmix "
+            "--container-name=ray-head" in script
+        )
+        assert "--gpus=0 --overlap --mpi=none --container-name=ray-head" not in script
+
     def test_env_vars_formatting(self):
         """Test that environment variables are properly formatted as export statements."""
         executor = SlurmExecutor(