From 4786322aa3ed54b7716c3ec9ca2ba1cc9595f023 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 02:30:49 -0500 Subject: [PATCH] Updated API of module that makes sparse arrays dense. --- src/ezmsg/event/sparse.py | 31 +++++++++++-------------------- tests/test_sparse.py | 10 +++++----- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/ezmsg/event/sparse.py b/src/ezmsg/event/sparse.py index 182fbb6..c4322c6 100644 --- a/src/ezmsg/event/sparse.py +++ b/src/ezmsg/event/sparse.py @@ -1,32 +1,23 @@ from dataclasses import replace -import typing -import numpy as np import ezmsg.core as ez -from ezmsg.sigproc.base import GenAxisArray -from ezmsg.util.generator import consumer from ezmsg.util.messages.axisarray import AxisArray - - -@consumer -def to_dense() -> typing.Generator[AxisArray, AxisArray, None]: - msg_out = AxisArray(np.array([]), dims=[""]) - while True: - msg_in: AxisArray = yield msg_out - if hasattr(msg_in.data, "todense"): - msg_out = replace(msg_in, data=msg_in.data.todense()) - else: - msg_out = msg_in +from ezmsg.sigproc.base import BaseTransformer, BaseTransformerUnit class DensifySettings(ez.Settings): pass -class Densify(GenAxisArray): - """:obj:`Unit` for :obj:`bandpower`.""" +class DensifyTransformer(BaseTransformer[DensifySettings, AxisArray, AxisArray]): + def _process(self, message: AxisArray) -> AxisArray: + if hasattr(message.data, "todense"): + return replace(message, data=message.data.todense()) + else: + return message - SETTINGS = DensifySettings - def construct_generator(self): - self.STATE.gen = to_dense() +class DensifyUnit( + BaseTransformerUnit[DensifySettings, AxisArray, AxisArray, DensifyTransformer] +): + SETTINGS = DensifySettings diff --git a/tests/test_sparse.py b/tests/test_sparse.py index b5a308a..fd8c5cc 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -3,21 +3,21 @@ import sparse from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.event.sparse import to_dense +from ezmsg.event.sparse import DensifyTransformer @pytest.mark.parametrize("sparse_input", [True, False]) -def test_to_dense(sparse_input: bool): +def test_densify(sparse_input: bool): arr_shape = (100, 50, 30) if sparse_input: rng = np.random.default_rng() data = sparse.random(arr_shape, density=0.1, random_state=rng) else: data = np.random.rand(*arr_shape) - in_msg = AxisArray(data=data, dims=["time", "ch", "freq"], key="test_to_dense") + in_msg = AxisArray(data=data, dims=["time", "ch", "freq"], key="test_densify") - proc = to_dense() - out_msg = proc.send(in_msg) + transformer = DensifyTransformer() + out_msg = transformer(in_msg) assert out_msg.data.shape == in_msg.data.shape assert isinstance(out_msg.data, np.ndarray) if sparse_input: