Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ssm-simulators"
version = "0.12.0"
version = "0.12.1"
description = "SSMS is a package collecting simulators and training data generators for cognitive science, neuroscience, and approximate bayesian computation"
authors = [
{ name = "Alexander Fengler", email = "alexander_fengler@brown.edu" },
Expand Down
42 changes: 40 additions & 2 deletions ssms/dataset_generators/lan_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import logging
import os
import uuid
import warnings
from copy import deepcopy
Expand Down Expand Up @@ -314,12 +315,49 @@ def _get_ncpus(self):
)

if n_cpus_config == "all":
n_cpus = psutil.cpu_count(logical=False)
n_cpus = None

# 1. Try process affinity first (best for SLURM/containers)
try:
if hasattr(os, "sched_getaffinity"):
n_cpus = len(os.sched_getaffinity(0))
elif hasattr(
os, "process_cpu_count"
): # pragma: no cover (Python 3.13+ without sched_getaffinity)
n_cpus = os.process_cpu_count()
except Exception as e:
logger.debug("Could not get process affinity: %s", e)

# 2. Fallback to OS logical core count
if n_cpus is None:
try:
n_cpus = os.cpu_count()
except Exception as e:
logger.debug("os.cpu_count() failed: %s", e)

# 3. Fallback to psutil
if n_cpus is None:
try:
n_cpus = psutil.cpu_count(logical=False) or psutil.cpu_count(
logical=True
)
except Exception as e:
logger.debug("psutil.cpu_count() failed: %s", e)

# 4. Fail fast if we STILL don't know
if n_cpus is None:
raise RuntimeError(
"Could not determine CPU count automatically. "
"Please specify 'n_cpus' explicitly as an integer in your config."
)

# 5. Informative logging
logger.info("Resolved 'n_cpus'='all' to %d cores.", n_cpus)
else:
n_cpus = n_cpus_config

# Update nested config
if "pipeline" not in self.generator_config:
if "pipeline" not in self.generator_config: # pragma: no cover
self.generator_config["pipeline"] = {}
self.generator_config["pipeline"]["n_cpus"] = n_cpus

Expand Down
66 changes: 66 additions & 0 deletions tests/dataset_generators/test_training_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,72 @@ def test_generate_data_likelihood_validity(self, minimal_config):
# (allowing some numerical edge cases)
assert np.median(lan_labels) < 0

def test_n_cpus_all_resolution_fallback(self, minimal_config, monkeypatch):
"""Test that n_cpus='all' raises RuntimeError when all detection methods fail."""
import os
import psutil

minimal_config["pipeline"]["n_cpus"] = "all"

def raise_exception(*args, **kwargs):
raise Exception("Mocked exception")

if hasattr(os, "sched_getaffinity"):
monkeypatch.setattr(os, "sched_getaffinity", raise_exception)
if hasattr(os, "process_cpu_count"):
monkeypatch.setattr(os, "process_cpu_count", raise_exception)
monkeypatch.setattr(os, "cpu_count", raise_exception)
monkeypatch.setattr(psutil, "cpu_count", raise_exception)

with pytest.raises(
RuntimeError, match="Could not determine CPU count automatically"
):
_ = TrainingDataGenerator(minimal_config, model_config["ddm"])

def test_n_cpus_all_resolution_via_affinity(
self, minimal_config, monkeypatch, caplog
):
"""Test successful resolution via os.sched_getaffinity."""
import logging
import os

minimal_config["pipeline"]["n_cpus"] = "all"

monkeypatch.setattr(
os, "sched_getaffinity", lambda pid: {0, 1, 2, 3}, raising=False
)

with caplog.at_level(logging.INFO):
gen = TrainingDataGenerator(minimal_config, model_config["ddm"])

assert gen.generator_config["pipeline"]["n_cpus"] == 4
assert "Resolved 'n_cpus'='all' to 4 cores." in caplog.text

def test_n_cpus_all_resolution_via_os_cpu_count(
self, minimal_config, monkeypatch, caplog
):
"""Test successful resolution via os.cpu_count when affinity is unavailable."""
import logging
import os

minimal_config["pipeline"]["n_cpus"] = "all"

def raise_exception(*args, **kwargs):
raise Exception("Mocked exception")

if hasattr(os, "sched_getaffinity"):
monkeypatch.setattr(os, "sched_getaffinity", raise_exception)
if hasattr(os, "process_cpu_count"):
monkeypatch.setattr(os, "process_cpu_count", raise_exception)

monkeypatch.setattr(os, "cpu_count", lambda: 8)

with caplog.at_level(logging.INFO):
gen = TrainingDataGenerator(minimal_config, model_config["ddm"])

assert gen.generator_config["pipeline"]["n_cpus"] == 8
assert "Resolved 'n_cpus'='all' to 8 cores." in caplog.text


class TestTrainingDataGeneratorPipelineIntegration:
"""Test TrainingDataGenerator integration with custom pipelines."""
Expand Down
Loading