Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@ jobs:
build:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10.15", "3.11", "3.12"]
os:
- "ubuntu-latest"
- "windows-latest"
- "macos-latest"
runs-on: ${{matrix.os}}

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v2
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
Expand All @@ -31,7 +29,7 @@ jobs:
run: uv python install ${{ matrix.python-version }}

- name: Install the project
run: uv sync --all-extras --dev
run: uv sync --all-extras

- name: Lint
run:
Expand Down
1 change: 0 additions & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
3.10
14 changes: 6 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ authors = [
{ name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" },
]
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
"ezmsg-sigproc>=1.5.0",
"ezmsg-sigproc>=1.6.0",
"numpy>=1.26",
"sparse>=0.15.4",
]
Expand All @@ -17,6 +17,10 @@ dependencies = [
test = [
"pytest>=8.3.3",
]
dev = [
"ruff>=0.6.8",
"typer>=0.12.5",
]

[build-system]
requires = ["hatchling", "hatch-vcs"]
Expand All @@ -30,9 +34,3 @@ version-file = "src/ezmsg/event/__version__.py"

[tool.hatch.build.targets.wheel]
packages = ["src/ezmsg"]

[tool.uv]
dev-dependencies = [
"ruff>=0.6.8",
"typer>=0.12.5",
]
22 changes: 10 additions & 12 deletions src/ezmsg/event/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def threshold_crossing(
return_peak_val: bool = False,
auto_scale_tau: float = 0.0,
) -> typing.Generator[
typing.Union[typing.List[EventMessage], AxisArray], AxisArray, None
list[EventMessage] | AxisArray, AxisArray, None
]:
"""
Detect threshold crossing events.
Expand All @@ -52,31 +52,29 @@ def threshold_crossing(
msg_out = AxisArray(np.array([]), dims=[""])

# Initialize state variables
sample_shape: typing.Optional[typing.Tuple[int, ...]] = None
fs: typing.Optional[float] = None
sample_shape: tuple[int, ...] | None = None
fs: float | None = None
max_width: int = 0
min_width: int = 1 # Consider making this a parameter.
refrac_width: int = 0

scaler: typing.Optional[typing.Generator[AxisArray, AxisArray, None]] = None
scaler: typing.Generator[AxisArray, AxisArray, None] | None = None
# adaptive z-scoring.
# TODO: This sample-by-sample adaptation is probably overkill. ezmsg-sigproc should add chunk-wise scaler updating.

_overs: typing.Optional[npt.NDArray] = (
None # (n_feats, <=max_width) int == -1 or +1
)
_overs: npt.NDArray | None = None # (n_feats, <=max_width) int == -1 or +1
# Trailing buffer to track whether the previous sample(s) were past threshold.

_data: typing.Optional[npt.NDArray] = None # (n_feats, <=max_width) in_dtype
_data: npt.NDArray | None = None # (n_feats, <=max_width) in_dtype
# Trailing buffer in case peak spans sample chunks. Only used if align_on_peak or return_peak_val.

_data_raw: typing.Optional[npt.NDArray] = None # (n_feats, <=max_width) in_dtype
_data_raw: npt.NDArray | None = None # (n_feats, <=max_width) in_dtype
# Only used if return_peak_val and scaler is not None

_elapsed: typing.Optional[npt.NDArray] = None # (n_feats,) int
_elapsed: npt.NDArray | None = None # (n_feats,) int
# Number of samples since last event. Used to enforce refractory period across iterations.
#
# _n_skip: typing.Optional[npt.NDArray] = None # (n_feats,) int
# _n_skip: npt.NDArray | None = None # (n_feats,) int

while True:
msg_in: AxisArray = yield msg_out
Expand Down Expand Up @@ -114,7 +112,7 @@ def threshold_crossing(
)

# Optionally scale data
data_raw: typing.Optional[npt.NDArray] = None
data_raw: npt.NDArray | None = None
if scaler is not None:
if return_peak_val:
data_raw = msg_in.data.copy()
Expand Down
4 changes: 2 additions & 2 deletions src/ezmsg/event/rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def event_rate(
window_shift=bin_duration,
zero_pad_until="none",
)
out_dims: typing.Optional[typing.List[str]] = None
out_axes: typing.Optional[typing.Dict[str, AxisArray.Axis]] = None
out_dims: list[str] | None = None
out_axes: dict[str, AxisArray.Axis] | None = None

while True:
msg_in: AxisArray = yield msg_out
Expand Down
24 changes: 12 additions & 12 deletions src/ezmsg/event/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

@consumer
def windowing(
axis: typing.Optional[str] = None,
axis: str | None = None,
newaxis: str = "win",
window_dur: typing.Optional[float] = None,
window_shift: typing.Optional[float] = None,
window_dur: float | None = None,
window_shift: float | None = None,
zero_pad_until: str = "input",
) -> typing.Generator[AxisArray, AxisArray, None]:
"""
Expand Down Expand Up @@ -70,14 +70,14 @@ def windowing(
msg_out = AxisArray(np.array([]), dims=[""])

# State variables
buffer: typing.Optional[sparse.SparseArray] = None
window_samples: typing.Optional[int] = None
window_shift_samples: typing.Optional[int] = None
buffer: sparse.SparseArray | None = None
window_samples: int | None = None
window_shift_samples: int | None = None
shift_deficit: int = 0
b_1to1 = window_shift is None
newaxis_warned: bool = b_1to1
out_newaxis: typing.Optional[AxisArray.Axis] = None
out_dims: typing.Optional[typing.List[str]] = None
out_newaxis: AxisArray.Axis | None = None
out_dims: list[str] | None = None

check_inputs = {"samp_shape": None, "fs": None, "key": None}

Expand Down Expand Up @@ -203,10 +203,10 @@ def windowing(


class WindowSettings(ez.Settings):
axis: typing.Optional[str] = None
newaxis: typing.Optional[str] = None # new axis for output. No new axes if None
window_dur: typing.Optional[float] = None # Sec. passthrough if None
window_shift: typing.Optional[float] = None # Sec. Use "1:1 mode" if None
axis: str | None = None
newaxis: str | None = None # new axis for output. No new axes if None
window_dur: float | None = None # Sec. passthrough if None
window_shift: float | None = None # Sec. Use "1:1 mode" if None
zero_pad_until: str = "full" # "full", "shift", "input", "none"


Expand Down
3 changes: 1 addition & 2 deletions tests/test_peak.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from pathlib import Path
import tempfile
import typing

import numpy as np
import pytest
Expand Down Expand Up @@ -86,7 +85,7 @@ def test_threshold_crossing(return_peak_val: bool):
assert np.array_equal(feat_inds, exp_feat_inds)


def get_test_fn(test_name: typing.Optional[str] = None, extension: str = "txt") -> Path:
def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path:
"""PYTEST compatible temporary test file creator"""
if test_name is None:
test_name = os.environ.get("PYTEST_CURRENT_TEST")
Expand Down
4 changes: 1 addition & 3 deletions tests/test_window.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import typing

import numpy as np
import pytest
import sparse
Expand All @@ -13,7 +11,7 @@
@pytest.mark.parametrize("zero_pad", ["input", "shift", "none"])
def test_sparse_window(
win_dur: float,
win_shift: typing.Optional[float],
win_shift: float | None,
zero_pad: str,
):
fs = 100.0
Expand Down
Loading