Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Upload Python Package - ezmsg-event
name: Upload Python Package

on:
release:
Expand All @@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v2
uses: astral-sh/setup-uv@v6

- name: Build Package
run: uv build
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ on:
push:
branches: [main]
pull_request:
branches: [main]
branches:
- main
- dev
workflow_dispatch:

jobs:
build:
strategy:
matrix:
python-version: ["3.10.15", "3.11", "3.12"]
python-version: ["3.10.15", "3.11", "3.12", "3.13"]
os:
- "ubuntu-latest"
runs-on: ${{matrix.os}}
Expand All @@ -20,7 +22,7 @@ jobs:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
uses: astral-sh/setup-uv@v6

- name: Install the project
run: uv sync --python ${{ matrix.python-version }}
Expand Down
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
# ezmsg-event

ezmsg namespace package for working with signal events like neural spikes and heartbeats

## Installation

```bash
pip install ezmsg-event
```

## Dependencies

- `ezmsg`
- `numpy`

## Usage

See the `examples` folder for usage examples.

## Development

We use [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for development.

1. Install [`uv`](https://docs.astral.sh/uv/getting-started/installation/) if not already installed.
2. Fork this repository and clone your fork locally.
3. Open a terminal and `cd` to the cloned folder.
4. Run `uv sync` to create a `.venv` and install dependencies.
5. (Optional) Install pre-commit hooks: `uv run pre-commit install`
6. After making changes, run the test suite: `uv run pytest tests`

## License

MIT License - see [LICENSE](LICENSE) for details.
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ readme = "README.md"
requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
"ezmsg-baseproc",
"ezmsg-sigproc>=2.4.0",
"ezmsg>=3.6.1",
"sparse>=0.17.0",
Expand All @@ -17,6 +18,7 @@ dependencies = [
[dependency-groups]
dev = [
"typer",
"pre-commit>=4.0.0",
{include-group = "lint"},
{include-group = "test"},
]
Expand Down Expand Up @@ -46,3 +48,20 @@ version-file = "src/ezmsg/event/__version__.py"

[tool.hatch.build.targets.wheel]
packages = ["src/ezmsg"]

[tool.ruff]
line-length = 120
target-version = "py310"
# Exclude auto-generated files
exclude = ["*/__version__.py"]

[tool.ruff.lint]
select = ["E", "F", "I", "W"]

[tool.ruff.lint.isort]
known-first-party = ["ezmsg.event"]
known-third-party = ["ezmsg"]

[tool.uv.sources]
# Uncomment to use development version of ezmsg from git
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "dev" }
5 changes: 0 additions & 5 deletions src/ezmsg/event/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from .__version__ import __version__ as __version__

from .rate import Rate as Rate
from .rate import EventRate as EventRate
from .binned import BinnedEventAggregator as BinnedEventAggregator
from .binned import BinnedEventAggregatorSettings as BinnedEventAggregatorSettings
43 changes: 12 additions & 31 deletions src/ezmsg/event/binned.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt

from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
Expand Down Expand Up @@ -33,9 +32,7 @@ class BinnedEventAggregatorState:


class BinnedEventAggregator(
BaseStatefulTransformer[
BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState
]
BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState]
):
def _hash_message(self, message: AxisArray) -> int:
targ_ax_idx = message.get_axis_idx(self.settings.axis)
Expand All @@ -45,9 +42,7 @@ def _hash_message(self, message: AxisArray) -> int:
def _reset_state(self, message: AxisArray) -> None:
self._state.n_overflow = 0
targ_axis_idx = message.get_axis_idx(self.settings.axis)
buff_shape = (
message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
)
buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64)

def _process(self, message: AxisArray) -> AxisArray:
Expand All @@ -63,20 +58,17 @@ def _process(self, message: AxisArray) -> AxisArray:
n_prev_overflow = self._state.n_overflow

if self._state.n_overflow > 0:
# Calculate how many samples from the input msg we can fit into the first bin, given the current overflow state
# Calculate how many samples from the input msg we can fit into the first bin,
# given the current overflow state
n_first = samples_per_bin - self._state.n_overflow
# Sum the number of samples in the first bin then add to self._state.counts_in_overflow
var_slice[targ_ax_idx] = slice(0, n_first)
first_bin_counts = (
message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
)
first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
first_bin_counts += self._state.counts_in_overflow
else:
n_first = 0
first_bin_counts = self._state.counts_in_overflow
assert np.all(first_bin_counts == 0), (
"Overflow state should be zeroed out from previous iteration."
)
assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration."

# Calculate how many samples remain after the first bin
n_remaining = message.data.shape[targ_ax_idx] - n_first
Expand All @@ -93,20 +85,14 @@ def _process(self, message: AxisArray) -> AxisArray:
+ (n_full_bins, samples_per_bin)
+ full_bins_data.shape[targ_ax_idx + 1 :]
)
middle_bin_counts = (
full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()
)
middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()

# Prepare output
if self._state.n_overflow > 0:
first_bin_counts = first_bin_counts.reshape(
first_bin_counts.shape[:targ_ax_idx]
+ (1,)
+ first_bin_counts.shape[targ_ax_idx:]
)
output_data = np.concatenate(
[first_bin_counts, middle_bin_counts], axis=targ_ax_idx
first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]
)
output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx)
else:
output_data = middle_bin_counts

Expand All @@ -123,10 +109,7 @@ def _process(self, message: AxisArray) -> AxisArray:
out_msg = replace(
message,
data=output_data,
axes={
k: v if k != self.settings.axis else out_axis
for k, v in message.axes.items()
},
axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()},
)

# Calculate and store the overflow state.
Expand All @@ -139,8 +122,6 @@ def _process(self, message: AxisArray) -> AxisArray:


class BinnedEventAggregatorUnit(
BaseTransformerUnit[
BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator
]
BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator]
):
SETTINGS = BinnedEventAggregatorSettings
42 changes: 14 additions & 28 deletions src/ezmsg/event/eventsfromrates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import numpy.typing as npt
import sparse
from ezmsg.sigproc.base import (
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
Expand Down Expand Up @@ -161,9 +161,7 @@ class EventsFromRatesState:


class EventsFromRatesTransformer(
BaseStatefulTransformer[
EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesState
]
BaseStatefulTransformer[EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesState]
):
def _reset_state(self, message: AxisArray) -> None:
ch_ax = message.get_axis_idx("ch")
Expand All @@ -178,14 +176,10 @@ def _process(self, message: AxisArray) -> AxisArray:
total_samples = n_bins * int(bin_duration * self.settings.output_fs)

# Get rates array with shape (n_bins, n_channels), contiguous for numba
rates_array = (
message.data / bin_duration if self.settings.assume_counts else message.data
)
rates_array = message.data / bin_duration if self.settings.assume_counts else message.data
if time_ax != 0:
rates_array = np.moveaxis(rates_array, time_ax, 0)
rates_array = np.ascontiguousarray(
np.maximum(rates_array, self.settings.min_rate)
)
rates_array = np.ascontiguousarray(np.maximum(rates_array, self.settings.min_rate))
n_channels = rates_array.shape[1]

# Estimate max events per channel based on actual input rates
Expand All @@ -194,25 +188,21 @@ def _process(self, message: AxisArray) -> AxisArray:
max_events_per_channel = max(int(max_input_rate * total_time * 3) + 10, 20)

# Generate events using numba (parallel across channels)
all_event_samples, event_counts, accumulated_out, threshold_out = (
_generate_events_all_channels(
rates_array,
self.state.accumulated,
self.state.threshold,
bin_duration,
self.settings.output_fs,
max_events_per_channel,
)
all_event_samples, event_counts, accumulated_out, threshold_out = _generate_events_all_channels(
rates_array,
self.state.accumulated,
self.state.threshold,
bin_duration,
self.settings.output_fs,
max_events_per_channel,
)

# Update state for next chunk
self.state.accumulated = accumulated_out
self.state.threshold = threshold_out

# Flatten per-channel arrays into coordinate arrays
event_samples, event_channels = _flatten_events_unsorted(
all_event_samples, event_counts
)
event_samples, event_channels = _flatten_events_unsorted(all_event_samples, event_counts)

# Build sparse array (COO handles sorting internally)
if len(event_samples) > 0:
Expand All @@ -230,9 +220,7 @@ def _process(self, message: AxisArray) -> AxisArray:
)

if self.settings.layout == "gcxs":
event_array = sparse.GCXS.from_coo(
event_array, compressed_axes=self.settings.compress_dims
)
event_array = sparse.GCXS.from_coo(event_array, compressed_axes=self.settings.compress_dims)

return replace(
message,
Expand All @@ -246,8 +234,6 @@ def _process(self, message: AxisArray) -> AxisArray:


class EventsFromRatesUnit(
BaseTransformerUnit[
EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesTransformer
]
BaseTransformerUnit[EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesTransformer]
):
SETTINGS = EventsFromRatesSettings
Loading