Cannot request --mem-per-gpu via job_extra_directives #698

@jlmeunier

Description

Describe the issue:
I want to request a specific amount of GPU memory for a CUDA job launched through a SLURMCluster instance.

However, if I pass a memory value together with job_extra_directives=["--mem-per-gpu=40G"],
sbatch rejects the job:
sbatch: fatal: --mem, --mem-per-cpu, and --mem-per-gpu are mutually exclusive.

The problem is that SLURMCluster refuses to be instantiated without memory=SOME_VALUE,
so I have no way to request the amount of GPU memory I need.

Unless I am missing something, in which case, sorry!

JL

Minimal Complete Verifiable Example:

import os
from dask.distributed import Client
from dask.distributed import WorkerPlugin
from dask.distributed import LocalCluster
#from dask.distributed import print # need to configure logging to see print output
from dask.distributed import warn
from dask_jobqueue import SLURMCluster

import torch

class MyWorker(WorkerPlugin):
    def setup(self, worker):
        # This runs once per worker process
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
        self.device = device
        # small constant matrix kept on the worker's device for the matmul task
        self.a = torch.arange(4, dtype=torch.float32).reshape(2, 2).to(self.device)

        warn(
            f"Worker ready | "
            f"address={worker.address} | "
            f"id={worker.id} | "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')} | "
            f"device={self.device}"
        )

    def multiply(self, i):
        b = torch.tensor([i]*2, dtype=torch.float32)
        b = b.to(self.device)
        c = torch.matmul(self.a, b)  # (2, 2) @ (2,) -> (2,)
        c = torch.matmul(c, b)       # (2,) @ (2,)  -> 0-d tensor
        return float(c)              # convert the scalar tensor to a Python float
    
    def teardown(self, worker):
        del self.a


def do_multiply(i):
    """Function called on each of our data instances.
    """
    from dask.distributed import get_worker
    worker = get_worker()
    plugin = worker.plugins["matmul"]
    return plugin.multiply(i)


if __name__ == '__main__':
    if False:
        cluster = LocalCluster(
                        n_workers=2,
                        threads_per_worker=1
                )
    else: # --- SLURM ---
        if False:
            #  *** NOT SPECIFYING memory
            # ValueError: You must specify how much cores and memory per job you want to use, for example:
            # cluster = SLURMCluster(cores=1, memory='24GB')
            cluster = SLURMCluster(
                queue='gpu-be',
                cores=1,
                #memory='16GB',  # *** PROBLEM HERE: memory omitted ***
                processes=1,
                walltime='02:00:00',
                job_extra_directives=['--gres=gpu:1']
            )
            cluster.scale(1)  # Launch 1 worker
        elif False:
            #  *** SPECIFYING memory and extra directives
            # sbatch: fatal: --mem, --mem-per-cpu, and --mem-per-gpu are mutually exclusive.
            cluster = SLURMCluster(
                queue='gpu-be',
                cores=1,
                memory='16GB',
                processes=1,
                walltime='02:00:00',
                job_extra_directives=['--gres=gpu:1', '--mem-per-gpu=40G']
            )
        else:
            # Only way that works. But then, how do I specify memory per GPU?
            cluster = SLURMCluster(
                queue='gpu-be',
                cores=1,
                memory='16GB',
                processes=1,
                walltime='02:00:00',
                job_extra_directives=['--gres=gpu:1']
            )
        cluster.scale(1)  # Launch 1 worker
        print(cluster.job_script())
        
    client = Client(cluster) 
    print(client.dashboard_link)

    client.register_plugin(MyWorker(), name="matmul")

    futures = client.map(do_multiply,
                         [0,1,2,3],
                         retries=0
                         )
    results = client.gather(futures)
    print(results)
    
    client.close()
    cluster.close()

Anything else we need to know?:
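One direction I am wondering about, but have NOT tested: keep memory= so that the constructor is satisfied (it also feeds the worker's memory limit), drop the generated "#SBATCH --mem=..." header line with job_directives_skip, and request GPU memory via job_extra_directives. I am assuming from the docs that job_directives_skip filters header lines by substring match, which is why I would pass '--mem=' rather than '--mem' (so that the --mem-per-gpu line is not removed as well). A minimal sketch:

from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(
    queue='gpu-be',
    cores=1,
    memory='16GB',      # still required; also used for the dask worker memory limit
    processes=1,
    walltime='02:00:00',
    # assumption: removes the generated "#SBATCH --mem=..." line and nothing else
    job_directives_skip=['--mem='],
    job_extra_directives=['--gres=gpu:1', '--mem-per-gpu=40G'],
)
print(cluster.job_script())  # verify that only --mem-per-gpu is left in the header

Note that even if this works, memory='16GB' still controls the dask worker's host-RAM limit, which is separate from the GPU memory requested with --mem-per-gpu.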

Environment:

  • Dask version: dask 2025.12.0 dask-jobqueue 0.9.0
  • Python version: 3.12
  • Operating System: Linux
  • Install method (conda, pip, source): pip in a conda env
