From 87708d2bd00753a73f66df1d2f29c86c03d66f4d Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 17 Mar 2026 20:01:41 -0400 Subject: [PATCH] fix: handle psutil errors gracefully in _get_ncpus and bump to 0.12.1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Robust CPU detection cascade: sched_getaffinity → cpu_count → psutil - Explicit RuntimeError if all detection methods fail - Always log resolved core count for HPC verification - pragma: no cover on platform-specific unreachable branches Made-with: Cursor --- pyproject.toml | 2 +- ssms/dataset_generators/lan_mlp.py | 42 +++++++++++- .../test_training_data_generator.py | 66 +++++++++++++++++++ 3 files changed, 107 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fb464e11..9e707c78 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ssm-simulators" -version = "0.12.0" +version = "0.12.1" description = "SSMS is a package collecting simulators and training data generators for cognitive science, neuroscience, and approximate bayesian computation" authors = [ { name = "Alexander Fengler", email = "alexander_fengler@brown.edu" }, diff --git a/ssms/dataset_generators/lan_mlp.py b/ssms/dataset_generators/lan_mlp.py index 390501e4..cf9d4b5d 100755 --- a/ssms/dataset_generators/lan_mlp.py +++ b/ssms/dataset_generators/lan_mlp.py @@ -9,6 +9,7 @@ """ import logging +import os import uuid import warnings from copy import deepcopy @@ -314,12 +315,49 @@ def _get_ncpus(self): ) if n_cpus_config == "all": - n_cpus = psutil.cpu_count(logical=False) + n_cpus = None + + # 1. Try process affinity first (best for SLURM/containers) + try: + if hasattr(os, "sched_getaffinity"): + n_cpus = len(os.sched_getaffinity(0)) + elif hasattr( + os, "process_cpu_count" + ): # pragma: no cover (Python 3.13+ without sched_getaffinity) + n_cpus = os.process_cpu_count() + except Exception as e: + logger.debug("Could not get process affinity: %s", e) + + # 2. Fallback to OS logical core count + if n_cpus is None: + try: + n_cpus = os.cpu_count() + except Exception as e: + logger.debug("os.cpu_count() failed: %s", e) + + # 3. Fallback to psutil + if n_cpus is None: + try: + n_cpus = psutil.cpu_count(logical=False) or psutil.cpu_count( + logical=True + ) + except Exception as e: + logger.debug("psutil.cpu_count() failed: %s", e) + + # 4. Fail fast if we STILL don't know + if n_cpus is None: + raise RuntimeError( + "Could not determine CPU count automatically. " + "Please specify 'n_cpus' explicitly as an integer in your config." + ) + + # 5. Informative logging + logger.info("Resolved 'n_cpus'='all' to %d cores.", n_cpus) else: n_cpus = n_cpus_config # Update nested config - if "pipeline" not in self.generator_config: + if "pipeline" not in self.generator_config: # pragma: no cover self.generator_config["pipeline"] = {} self.generator_config["pipeline"]["n_cpus"] = n_cpus diff --git a/tests/dataset_generators/test_training_data_generator.py b/tests/dataset_generators/test_training_data_generator.py index 8879cf47..ec0ba984 100644 --- a/tests/dataset_generators/test_training_data_generator.py +++ b/tests/dataset_generators/test_training_data_generator.py @@ -259,6 +259,72 @@ def test_generate_data_likelihood_validity(self, minimal_config): # (allowing some numerical edge cases) assert np.median(lan_labels) < 0 + def test_n_cpus_all_resolution_fallback(self, minimal_config, monkeypatch): + """Test that n_cpus='all' raises RuntimeError when all detection methods fail.""" + import os + import psutil + + minimal_config["pipeline"]["n_cpus"] = "all" + + def raise_exception(*args, **kwargs): + raise Exception("Mocked exception") + + if hasattr(os, "sched_getaffinity"): + monkeypatch.setattr(os, "sched_getaffinity", raise_exception) + if hasattr(os, "process_cpu_count"): + monkeypatch.setattr(os, "process_cpu_count", raise_exception) + monkeypatch.setattr(os, "cpu_count", raise_exception) + monkeypatch.setattr(psutil, "cpu_count", raise_exception) + + with pytest.raises( + RuntimeError, match="Could not determine CPU count automatically" + ): + _ = TrainingDataGenerator(minimal_config, model_config["ddm"]) + + def test_n_cpus_all_resolution_via_affinity( + self, minimal_config, monkeypatch, caplog + ): + """Test successful resolution via os.sched_getaffinity.""" + import logging + import os + + minimal_config["pipeline"]["n_cpus"] = "all" + + monkeypatch.setattr( + os, "sched_getaffinity", lambda pid: {0, 1, 2, 3}, raising=False + ) + + with caplog.at_level(logging.INFO): + gen = TrainingDataGenerator(minimal_config, model_config["ddm"]) + + assert gen.generator_config["pipeline"]["n_cpus"] == 4 + assert "Resolved 'n_cpus'='all' to 4 cores." in caplog.text + + def test_n_cpus_all_resolution_via_os_cpu_count( + self, minimal_config, monkeypatch, caplog + ): + """Test successful resolution via os.cpu_count when affinity is unavailable.""" + import logging + import os + + minimal_config["pipeline"]["n_cpus"] = "all" + + def raise_exception(*args, **kwargs): + raise Exception("Mocked exception") + + if hasattr(os, "sched_getaffinity"): + monkeypatch.setattr(os, "sched_getaffinity", raise_exception) + if hasattr(os, "process_cpu_count"): + monkeypatch.setattr(os, "process_cpu_count", raise_exception) + + monkeypatch.setattr(os, "cpu_count", lambda: 8) + + with caplog.at_level(logging.INFO): + gen = TrainingDataGenerator(minimal_config, model_config["ddm"]) + + assert gen.generator_config["pipeline"]["n_cpus"] == 8 + assert "Resolved 'n_cpus'='all' to 8 cores." in caplog.text + class TestTrainingDataGeneratorPipelineIntegration: """Test TrainingDataGenerator integration with custom pipelines."""