Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a8e6f36
Fix test suite failures for serial (DS) and low-budget (NGOpt) optimi…
suraj-2309 Nov 27, 2025
619bf09
Update sbi support to recent versions
mstimberg Dec 11, 2025
f64107a
Test Python versions according to SPEC-0
mstimberg Dec 11, 2025
ec32b03
Use uv for Python env
mstimberg Dec 11, 2025
f3ad815
Remove `nlopt` (no 3.14 wheels)
mstimberg Dec 11, 2025
a293d32
Update dependencies
mstimberg Dec 11, 2025
22c92e4
Increase timeout for tests
mstimberg Dec 17, 2025
d2bd0ce
Do not install sbi and nlopt on Python 3.14
mstimberg Jan 5, 2026
e261569
Remove fcmaes (optional) dependency
mstimberg Jan 5, 2026
6d82221
Use a simple matrix for testing
mstimberg Jan 5, 2026
e4d3a58
Do not test on 3.14, do not test efel
mstimberg Jan 6, 2026
509d2e7
fix typo in extra dependency
mstimberg Jan 7, 2026
40fe482
Install dev dependencies and lint
mstimberg Jan 7, 2026
7381191
Do not test optional algorithms
mstimberg Jan 7, 2026
630fa00
Verbose install for tests
mstimberg Jan 7, 2026
5dd154a
Override torch dependency for now
mstimberg Jan 7, 2026
e6cfc23
Skip tests requiring efel if not installed
mstimberg Jan 7, 2026
6697162
Fix `load_posterior` for torch≥2.6
mstimberg Jan 12, 2026
b2d382e
Avoid numpy 2.4
mstimberg Jan 12, 2026
643a05a
Use coveralls action and increase test timeout
mstimberg Jan 12, 2026
318d00b
Use lcov format for coverage report
mstimberg Jan 12, 2026
1733e5b
Test with efel again
mstimberg Jan 13, 2026
6efe40d
Check whether things work for torch < 2.6
mstimberg Jan 13, 2026
2d83862
Override pytorch dependency for newer Python versions
mstimberg Jan 14, 2026
6e03051
Override sbi's upper version limit for torch
mstimberg Jan 14, 2026
e2776cf
Test again with Python 3.14
mstimberg Jan 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,56 @@ name: Tests
on: [push, pull_request]

jobs:
get_python_versions:
name: "Determine Python versions"
runs-on: ubuntu-latest
permissions: {}
outputs:
min-python: ${{ steps.nep29.outputs.min-python }}
max-python: "${{ steps.nep29.outputs.max-python }}"
steps:
- name: "calculate versions according to SPEC-0"
id: nep29
uses: mstimberg/github-calc-nep29@a73481e4e8488a5fa0b3be70a385cc5206a261ba # v0.7
with:
token: ${{ secrets.GITHUB_TOKEN }}
# Match SPEC-0
deprecate-python-after: 36
min-python-releases: 0

build:

runs-on: ubuntu-latest
needs: [get_python_versions]
name: "Python ${{ matrix.python-version }} (latest Brian: ${{ matrix.latest-brian }})"
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.11]
latest-brian: [true, false]
python-version: ["${{ needs.get_python_versions.outputs.min-python }}", "${{ needs.get_python_versions.outputs.max-python }}"]
latest-brian: [false, true]

steps:
- name: Checkout Repository
uses: actions/checkout@v3
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install brian2modelfitting
run: |
python -m pip install --upgrade pip wheel
python -m pip install flake8 pytest-coverage pytest-timeout coveralls
python -m pip install ".[all]"
run: uv sync -v --extra skopt --extra sbi --extra efel --extra test --dev # Not testing efel for now, since incompatible with numpy 2
- name: Update to latest Brian development version
run: python -m pip install -i https://test.pypi.org/simple/ --pre --upgrade Brian2
run: uv pip install -i https://test.pypi.org/simple/ --pre --upgrade Brian2
if: ${{ matrix.latest-brian }}
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Check for syntax errors and undefined names
run: uv run ruff check . --select=E9,F63,F7,F82
- name: Test with pytest
run: |
pytest --timeout=60 --cov=brian2modelfitting
uv run --no-sync --frozen pytest --timeout=240 --cov=brian2modelfitting --cov-report=lcov
- name: Upload coverage to coveralls
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
if: ${{ matrix.python-version == '3.8' && !matrix.latest-brian }}
run: coveralls --service=github

uses: coverallsapp/github-action@648a8eb78e6d50909eff900e4ec85cab4524a45b # v2.3.6
if: ${{ matrix.python-version == needs.get_python_versions.outputs.min-python && ! matrix.latest-brian}}
50 changes: 30 additions & 20 deletions brian2modelfitting/inferencer.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
"""
Module to perform simulation-based inference with the ``sbi`` library.
"""
import warnings
from numbers import Number
from typing import Mapping
import warnings

import numpy as np
from brian2.core.functions import Function
from brian2.core.namespace import get_local_namespace
from brian2.core.network import Network
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.devices.device import get_device, device
from brian2.devices.device import device, get_device
from brian2.equations.equations import Equations
from brian2.groups.neurongroup import NeuronGroup
from brian2.input.timedarray import TimedArray
from brian2.monitors.spikemonitor import SpikeMonitor
from brian2.monitors.statemonitor import StateMonitor
from brian2.units.fundamentalunits import (DIMENSIONLESS,
fail_for_dimension_mismatch,
get_dimensions,
Quantity)
from brian2.units.fundamentalunits import (
DIMENSIONLESS,
Quantity,
fail_for_dimension_mismatch,
get_dimensions,
)
from brian2.utils.logger import get_logger

from brian2modelfitting.fitter import get_spikes
import numpy as np

try:
import sbi
import torch
except ImportError:
sbi = None
torch = None

from .base import (handle_input_args,
handle_output_args,
handle_param_init,
input_equations,
output_equations,
output_dims)
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .base import (
handle_input_args,
handle_output_args,
handle_param_init,
input_equations,
output_dims,
output_equations,
)
from .simulator import CPPStandaloneSimulator, RuntimeSimulator
from .utils import tqdm


logger = get_logger(__name__)


Expand Down Expand Up @@ -606,9 +611,7 @@ def init_inference(self, inference_method, density_estimator_model, prior,
Instantiated inference object.
"""
import sbi.inference
from sbi.utils.get_nn_models import (posterior_nn,
likelihood_nn,
classifier_nn)
from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn
try:
inference_method = str.upper(inference_method)
inference_method_fun = getattr(sbi.inference, inference_method)
Expand Down Expand Up @@ -960,8 +963,15 @@ def load_posterior(self, f):
Loaded neural posterior with defined method family, density
estimator state dictionary, the prior over parameters and
the output shape of the simulator.
"""
p = torch.load(f)

Notes
-----
Only use this function to load files from trusted sources. It will
call `torch.load` with ``weights_only=False``, potentially resulting
in arbitrary code execution. See
https://pytorch.org/docs/stable/generated/torch.load.html
"""
p = torch.load(f, weights_only=False)
self.posterior = p
return p

Expand Down
1 change: 0 additions & 1 deletion brian2modelfitting/tests/test_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def test_infer_step(setup_full):
n_samples=10,
inference=inference)
assert isinstance(posterior, DirectPosterior)
assert_equal(np.array(posterior._x_shape), np.array([1, 5]))


def test_infer_step_errors(setup_full):
Expand Down
3 changes: 3 additions & 0 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def test_get_errors_gamma():

@pytest.mark.parametrize("parallel_processes", [0, 2, -1, -2]) # only testing that it works at all
def test_calc_EFL(parallel_processes):
pytest.importorskip("efel")
# "voltage traces" that are constant at -70*mV, -60mV, -50mV, -40mV for
# 50ms each.
dt = 1*ms
Expand All @@ -239,6 +240,7 @@ def test_calc_EFL(parallel_processes):


def test_get_features_feature_metric():
pytest.importorskip("efel")
# "voltage traces" that are constant at -70*mV, -60mV, -50mV, -40mV for
# 50ms each.
voltage_target = np.ones((2, 200)) * np.repeat([-70, -60, -50, -40], 50) * mV
Expand Down Expand Up @@ -276,6 +278,7 @@ def test_get_features_feature_metric():


def test_get_errors_feature_metric():
pytest.importorskip("efel")
# Fake results
features = [{'feature1': np.array([0, 0.5]),
'feature2': np.array([1, 2])},
Expand Down
61 changes: 46 additions & 15 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,46 @@
'''
Test the modelfitting module
'''
import pytest
import brian2.numpy_ as np # for unit-awareness
import pandas as pd
import pytest
import scipy.optimize

from numpy.testing import assert_equal, assert_almost_equal
from brian2 import (zeros, Equations, NeuronGroup, StateMonitor, TimedArray,
nS, mV, volt, ms, pA, pF, Quantity, set_device, get_device,
Network, have_same_dimensions, DimensionMismatchError)
from brian2.equations.equations import DIFFERENTIAL_EQUATION, SUBEXPRESSION
import brian2.numpy_ as np # for unit-awareness
from brian2modelfitting import (NevergradOptimizer, TraceFitter, MSEMetric,
OnlineTraceFitter, Simulator, Metric,
Optimizer, GammaFactor, FeatureMetric)
from brian2 import (
DimensionMismatchError,
Equations,
Network,
NeuronGroup,
Quantity,
StateMonitor,
TimedArray,
get_device,
have_same_dimensions,
ms,
mV,
nS,
pA,
pF,
set_device,
volt,
zeros,
)
from brian2.devices.device import reinit_devices, reset_device
from brian2.equations.equations import DIFFERENTIAL_EQUATION, SUBEXPRESSION
from numpy.testing import assert_almost_equal, assert_equal

from brian2modelfitting import (
FeatureMetric,
GammaFactor,
Metric,
MSEMetric,
NevergradOptimizer,
OnlineTraceFitter,
Optimizer,
Simulator,
TraceFitter,
)
from brian2modelfitting.fitter import get_param_dic


E = 40*mV
input_traces = zeros((10, 5))*volt
for i in range(5):
Expand Down Expand Up @@ -339,6 +362,8 @@ def test_tracefitter_fit_default_metric(setup):


from nevergrad.optimization import registry as nevergrad_registry


@pytest.mark.parametrize('method', sorted(nevergrad_registry.keys()))
def test_fitter_fit_methods(method):
dt = 0.01 * ms
Expand All @@ -347,15 +372,20 @@ def test_fitter_fit_methods(method):
g : siemens (constant)
E : volt (constant)
''')
# Fix for optimizers that don't support parallelization (DS)
# or have small fixed budgets (NGOptSingle)
n_samples = 30
if any(name in method for name in ['DS', 'NGOptSingle']):
n_samples = 1
tf = TraceFitter(dt=dt,
model=model,
input_var='v',
output_var='I',
input=input_traces,
output=output_traces,
n_samples=30)
n_samples=n_samples)
# Skip a few methods that seem to hang due to multi-threading deadlocks (?) or simply take very long
skip = ['BO', 'ParaPortfolio', 'BAR', 'MultiBFGS', 'MultiCobyla', 'MultiSQP', 'NgIohRW', 'F3SQPCMA']
skip = ['MultiDS', 'BO', 'ParaPortfolio', 'BAR', 'MultiBFGS', 'MultiCobyla', 'MultiSQP', 'NgIohRW', 'F3SQPCMA']
if any(s in method for s in skip):
pytest.skip(f'Skipping method {method}')

Expand Down Expand Up @@ -702,6 +732,7 @@ def test_fitter_refine_reuse_tsteps_multiobjective(setup_constant_multiobjective


def test_fitter_refine_errors(setup):
pytest.importorskip("efel")
dt, tf = setup
with pytest.raises(TypeError):
# Missing start parameter
Expand Down Expand Up @@ -1136,7 +1167,7 @@ def test_multiobjective_basic(setup_multiobjective):

def test_multiobjective_no_units(setup_multiobjective_no_units):
dt, tf = setup_multiobjective_no_units
result, error = tf.fit(n_rounds=20,
result, error = tf.fit(n_rounds=30,
metric={'var1': MSEMetric(t_start=50*ms),
'var2': MSEMetric(t_start=50*ms, normalization=0.001)},
optimizer=n_opt,
Expand Down
16 changes: 11 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ authors = [
{name = 'Marcel Stimberg'},
{name ='Romain Brette'}
]
requires-python = '>=3.8'
requires-python = '>=3.12'
dependencies = [
'numpy>=1.21',
'numpy>=2.0,<2.4', # Nevergrad depends on library incompatible with 2.4
'brian2>=2.2',
'nevergrad>=0.4',
'scikit-learn>=0.22',
Expand All @@ -30,16 +30,16 @@ classifiers = [
]

[project.optional-dependencies]
test = ['pytest']
test = ['pytest', 'pytest-coverage', 'pytest-timeout']
docs = ['sphinx>=1.8']
algos = [ # additional optimizers for nevergrad
'cma>=3.0', 'fcmaes', 'loguru', # loguru seems to be an undeclared dependency of fcmaes
'cma>=3.0',
'nlopt',
'poap',
'ConfigSpace']
skopt = ['scikit-optimize']
efel = ['efel']
sbi = ['sbi>=0.16.0']
sbi = ['sbi>=0.23.0']
all = ['brian2modelfitting[test]',
'brian2modelfitting[docs]',
'brian2modelfitting[algos]',
Expand All @@ -48,6 +48,12 @@ all = ['brian2modelfitting[test]',
'brian2modelfitting[sbi]'
]

[dependency-groups]
dev = ['ruff']

[tool.uv]
override-dependencies = ["torch"] # Remove sbi's upper dependency

[project.urls]
Documentation ='https://brian2modelfitting.readthedocs.io/'
Source = 'https://github.com/brian-team/brian2modelfitting'
Expand Down