From d2f9003fdb16518a2509e423d2d3723b27d2c476 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 Jan 2026 15:02:25 -0800 Subject: [PATCH 1/4] feat: support container-image None in slurm Signed-off-by: Hemil Desai --- nemo_run/core/execution/slurm.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 2b52d292..c273c2c8 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -369,12 +369,12 @@ def merge( main_executor.run_as_group = True if main_executor.het_group_indices: - assert main_executor.heterogeneous, ( - "heterogeneous must be True if het_group_indices is provided" - ) - assert len(main_executor.het_group_indices) == num_tasks, ( - "het_group_indices must be the same length as the number of tasks" - ) + assert ( + main_executor.heterogeneous + ), "heterogeneous must be True if het_group_indices is provided" + assert ( + len(main_executor.het_group_indices) == num_tasks + ), "het_group_indices must be the same length as the number of tasks" assert all( x <= y for x, y in zip( @@ -858,9 +858,9 @@ def materialize(self) -> str: sbatch_flags = [] if self.executor.heterogeneous: - assert len(self.jobs) == len(self.executor.resource_group), ( - f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." - ) + assert ( + len(self.jobs) == len(self.executor.resource_group) + ), f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." final_group_index = len(self.executor.resource_group) - 1 if self.executor.het_group_indices: final_group_index = self.executor.het_group_indices.index( @@ -870,9 +870,9 @@ def materialize(self) -> str: for i in range(len(self.executor.resource_group)): resource_req = self.executor.resource_group[i] if resource_req.het_group_index is not None: - assert self.executor.resource_group[i - 1].het_group_index is not None, ( - "het_group_index must be set for all requests in resource_group" - ) + assert ( + self.executor.resource_group[i - 1].het_group_index is not None + ), "het_group_index must be set for all requests in resource_group" if ( i > 0 and resource_req.het_group_index From d97e5942914ccdb5e6f826c933e86a7944b7759e Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 Jan 2026 15:27:31 -0800 Subject: [PATCH 2/4] fix Signed-off-by: Hemil Desai --- nemo_run/core/execution/slurm.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index c273c2c8..2b52d292 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -369,12 +369,12 @@ def merge( main_executor.run_as_group = True if main_executor.het_group_indices: - assert ( - main_executor.heterogeneous - ), "heterogeneous must be True if het_group_indices is provided" - assert ( - len(main_executor.het_group_indices) == num_tasks - ), "het_group_indices must be the same length as the number of tasks" + assert main_executor.heterogeneous, ( + "heterogeneous must be True if het_group_indices is provided" + ) + assert len(main_executor.het_group_indices) == num_tasks, ( + "het_group_indices must be the same length as the number of tasks" + ) assert all( x <= y for x, y in zip( @@ -858,9 +858,9 @@ def materialize(self) -> str: sbatch_flags = [] if self.executor.heterogeneous: - assert ( - len(self.jobs) == len(self.executor.resource_group) - ), f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." + assert len(self.jobs) == len(self.executor.resource_group), ( + f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." + ) final_group_index = len(self.executor.resource_group) - 1 if self.executor.het_group_indices: final_group_index = self.executor.het_group_indices.index( @@ -870,9 +870,9 @@ def materialize(self) -> str: for i in range(len(self.executor.resource_group)): resource_req = self.executor.resource_group[i] if resource_req.het_group_index is not None: - assert ( - self.executor.resource_group[i - 1].het_group_index is not None - ), "het_group_index must be set for all requests in resource_group" + assert self.executor.resource_group[i - 1].het_group_index is not None, ( + "het_group_index must be set for all requests in resource_group" + ) if ( i > 0 and resource_req.het_group_index From 2af4f5ee800d6665deece2f4f0efed89d1822d7d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 2 Mar 2026 16:55:10 -0800 Subject: [PATCH 3/4] Honor executor srun_args for Ray command srun Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 26 ++++++++++++++++++++------ nemo_run/run/ray/templates/ray.sub.j2 | 2 +- test/run/ray/test_slurm_ray_request.py | 18 ++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index b82ecf4b..643d5f49 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: return " ".join(_srun_flags) + def get_command_srun_args() -> str: + if ( + self.executor.run_as_group + and self.executor.heterogeneous + and self.executor.resource_group + and self.executor.resource_group[0].srun_args is not None + ): + command_srun_args = self.executor.resource_group[0].srun_args + else: + command_srun_args = self.executor.srun_args or [] + + return " ".join(shlex.quote(arg) for arg in command_srun_args) + ray_log_prefix = job_details.ray_log_prefix vars_to_fill = { "sbatch_flags": sbatch_flags, @@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: "ray_log_prefix": ray_log_prefix, "heterogeneous": self.executor.heterogeneous, "resource_group": self.executor.resource_group if self.executor.heterogeneous else [], + "command_srun_args": get_command_srun_args(), } if self.command_groups: @@ -1257,9 +1271,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): # Rsync workdir honouring .gitignore self.executor.tunnel.connect() - assert self.executor.tunnel.session is not None, ( - "Tunnel session is not connected" - ) + assert ( + self.executor.tunnel.session is not None + ), "Tunnel session is not connected" rsync( self.executor.tunnel.session, workdir, @@ -1314,9 +1328,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): self.executor.tunnel.connect() - assert self.executor.tunnel.session is not None, ( - "Tunnel session is not connected" - ) + assert ( + self.executor.tunnel.session is not None + ), "Tunnel session is not connected" rsync( self.executor.tunnel.session, os.path.join(local_code_extraction_path, ""), diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 8d510cb2..81a25c2c 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -454,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}" COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }} if [[ -n "$COMMAND" ]]; then - srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND" + srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap {% if command_srun_args %}{{ command_srun_args }} {% endif %}--container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py index d9a41ae7..373238f3 100644 --- a/test/run/ray/test_slurm_ray_request.py +++ b/test/run/ray/test_slurm_ray_request.py @@ -627,6 +627,24 @@ def test_command_groups_without_resource_group(self): assert "--overlap" in script assert "cmd1" in script # Second command in the list (index 1) + def test_command_srun_honors_executor_srun_args(self): + """Test that the COMMAND launch srun includes executor srun_args.""" + executor = SlurmExecutor(account="test_account", srun_args=["--mpi=pmix"]) + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-cluster", + cluster_dir="/tmp/test_jobs/test-ray-cluster", + template_name="ray.sub.j2", + executor=executor, + command="echo hello", + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + assert "--gpus=0 --overlap --mpi=pmix --container-name=ray-head" in script + def test_env_vars_formatting(self): """Test that environment variables are properly formatted as export statements.""" executor = SlurmExecutor( From 911edd925b3eb9c03da5f458439d05eeb4e8af9d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 2 Mar 2026 23:00:03 -0800 Subject: [PATCH 4/4] Fix ray srun formatting and cover heterogeneous command args Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 12 +++---- test/run/ray/test_slurm_ray_request.py | 46 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 643d5f49..d631397c 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -1271,9 +1271,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): # Rsync workdir honouring .gitignore self.executor.tunnel.connect() - assert ( - self.executor.tunnel.session is not None - ), "Tunnel session is not connected" + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) rsync( self.executor.tunnel.session, workdir, @@ -1328,9 +1328,9 @@ def start( if isinstance(self.executor.tunnel, SSHTunnel): self.executor.tunnel.connect() - assert ( - self.executor.tunnel.session is not None - ), "Tunnel session is not connected" + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) rsync( self.executor.tunnel.session, os.path.join(local_code_extraction_path, ""), diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py index 373238f3..652d283b 100644 --- a/test/run/ray/test_slurm_ray_request.py +++ b/test/run/ray/test_slurm_ray_request.py @@ -645,6 +645,52 @@ def test_command_srun_honors_executor_srun_args(self): script = request.materialize() assert "--gpus=0 --overlap --mpi=pmix --container-name=ray-head" in script + def test_command_srun_honors_head_resource_group_srun_args(self): + """Test that heterogeneous grouped runs use head resource-group srun_args for COMMAND.""" + executor = SlurmExecutor( + account="test_account", + heterogeneous=True, + srun_args=["--mpi=none"], + ) + executor.run_as_group = True + executor.resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + container_image="image1", + container_mounts=["/data:/data"], + srun_args=["--mpi=pmix"], + het_group_index=0, + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + container_image="image2", + container_mounts=["/data:/data"], + het_group_index=1, + ), + ] + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-cluster", + cluster_dir="/tmp/test_jobs/test-ray-cluster", + template_name="ray.sub.j2", + executor=executor, + command="echo hello", + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + assert ( + "--het-group=0 --no-container-mount-home --gpus=0 --overlap --mpi=pmix " + "--container-name=ray-head" in script + ) + assert "--gpus=0 --overlap --mpi=none --container-name=ray-head" not in script + def test_env_vars_formatting(self): """Test that environment variables are properly formatted as export statements.""" executor = SlurmExecutor(