From 83f7c814a4cec4e21f0127fb85bb098ef7568f63 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Tue, 10 May 2022 13:43:59 +0200 Subject: [PATCH 01/19] setup slicing --- bwpy/__init__.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 1f68a23..fc2f8c4 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -3,6 +3,7 @@ import numpy as np from ._hdf_annotations import requires_write_access from ._channels import Channel, ChannelGroup +import functools __version__ = "0.0.1a0" @@ -107,10 +108,62 @@ def guid(self): return self.attrs["GUID"].decode() -class BRWFile(File): +class _Slicer: + def __init__(self, slice): + self._slice = slice + + +class _TimeSlicer(_Slicer): + def __getitem__(self, instruction): + return self._slice._time_slice(instruction) + + +class _ChannelSlicer(_Slicer): + def __getitem__(self, instruction): + return self._slice._channel_slice(instruction) + + +class _Slice: + def __init__(self, file): + self._file = file + + @property + @functools.cache + def t(self): + return _TimeSlicer(self) + + @property + @functools.cache + def ch(self): + return _ChannelSlicer(self) + + def _time_slice(self, instruction): + print("Requested time slice", instruction) + + def _channel_slice(self, instruction): + print("Requested channel slice", instruction) + + +class BRWFile(File, _Slice): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + _Slice.__init__(self, self) + def _get_descr_prefix(self): return "BRW-File Level3" + @property + def channels(self): + return self['/3BRecInfo/3BMeaStreams/Raw/Chs'] + + @property + def data(self): + return self['/3BData/Raw'] + + @property + def channels_layout(self): + return self['/3BData/Raw'] + class BXRFile(File): @property From 6efeebe6cc129e2ef64ce7c5f9f7ac0166f03a4a Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Fri, 13 May 2022 14:35:44 +0200 Subject: [PATCH 02/19] testing completed --- bwpy/__init__.py | 85 ++++++++++++++++++++++++++++++++++++++----- setup.py | 2 +- tests/test_file.py | 2 +- tests/test_slicing.py | 50 +++++++++++++++++++++++++ 4 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 tests/test_slicing.py diff --git a/bwpy/__init__.py b/bwpy/__init__.py index fc2f8c4..24e85c8 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -31,9 +31,11 @@ def _establish_type(self): if self.description.startswith("BRW"): self.__class__ = BRWFile self._type = "brw" + self.__post_init__() if self.description.startswith("BXR"): self.__class__ = BXRFile self._type = "bxr" + self.__post_init__() @property def description(self): @@ -99,6 +101,11 @@ def get_raw_recording_info(self): def get_raw_user_info(self): return self["3BUserInfo"] + def convert(self, dv): + step_v = self.signal_inversion * ((self.max_volt - self.min_volt) / 2 ** self.bit_depth) + v_offset = self.signal_inversion * self.min_volt + return dv * step_v + v_offset + @property def version(self): return self.attrs["Version"] @@ -115,17 +122,27 @@ def __init__(self, slice): class _TimeSlicer(_Slicer): def __getitem__(self, instruction): - return self._slice._time_slice(instruction) + slice = self._slice._time_slice(instruction) + #self._file.n_frames = + return slice class _ChannelSlicer(_Slicer): def __getitem__(self, instruction): - return self._slice._channel_slice(instruction) + slice = self._slice._channel_slice(instruction) + #self._file.n_channel = + return slice class _Slice: - def __init__(self, file): + def __init__(self, file, channels=None, time=None): self._file = file + if channels is None: + channels = file.channels[()].reshape(file.layout.shape) + self._channels = channels + if time is None: + time = slice(None) + self._time = time @property @functools.cache @@ -137,16 +154,57 @@ def t(self): def ch(self): return _ChannelSlicer(self) + @property + def channels(self): + return self._channels + + @property + def data(self): + t_start, t_stop, t_step = self._time.indices(self._file.n_frames) + time_slice = slice( + t_start * self._file.n_channels, + t_stop * self._file.n_channels, + t_step * self._file.n_channels, + ) + cols = self._file.layout.shape[1] + mask = np.array([(row - 1) * cols + (col - 1) for row, col in self.channels[()].ravel()]) + + if len(self._file.raw) < len(mask): + raise KeyError(f"You recorded less than 1 value per channel.") + + start, stop, step = time_slice.indices(len(self._file.raw)) + time_ind = np.tile(mask, ((stop - start) // step, 1)) + for i, time_sample in enumerate(range(start, stop, step)): + time_ind[i, :] += time_sample + time_ind = time_ind.reshape(-1) + data = np.empty(time_ind.shape) + for i in range(0, len(time_ind), 1000): + end_slice = i + 1000 + data[i:end_slice] = self._file.raw[time_ind[i:end_slice]] + return self._file.convert(data.reshape((len(mask), -1))) + + def _slice_index(self): + return + def _time_slice(self, instruction): - print("Requested time slice", instruction) + if isinstance(instruction, int): + instruction = slice(instruction, instruction + 1, 1) + if isinstance(instruction, slice): + prev_start, prev_stop, prev_step = self._time.indices(self._file.n_frames) + _len = (prev_stop - prev_start) // prev_step + this_start, this_stop, this_step = instruction.indices(_len) + start = prev_start + this_start * prev_step + stop = prev_start + this_stop * prev_step + step = this_step * prev_step + return _Slice(self._file, self._channels, slice(start, stop, step)) + def _channel_slice(self, instruction): - print("Requested channel slice", instruction) + return _Slice(self._file, self._channels[instruction], self._time) class BRWFile(File, _Slice): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __post_init__(self): _Slice.__init__(self, self) def _get_descr_prefix(self): @@ -157,15 +215,22 @@ def channels(self): return self['/3BRecInfo/3BMeaStreams/Raw/Chs'] @property - def data(self): - return self['/3BData/Raw'] + def n_channels(self): + return self.channels.shape[0] @property - def channels_layout(self): + def layout(self): + return self['/3BRecInfo/3BMeaChip/Layout'] + + @property + def raw(self): return self['/3BData/Raw'] class BXRFile(File): + def __post_init__(self): + pass + @property def channel_groups(self): return self.get_channel_groups() diff --git a/setup.py b/setup.py index 8bdee76..f8df8c2 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,6 @@ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - install_requires=["h5py<3.0.0", "numpy"], + install_requires=["h5py", "numpy"], extras_require={"dev": ["sphinx", "furo"]}, ) diff --git a/tests/test_file.py b/tests/test_file.py index 59fa178..e34a8cd 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -116,4 +116,4 @@ class TestUserInfo(unittest.TestCase): def test_raw_user_info(self): with open_sample(samples.bxr, "r") as f: raw = f.get_raw_user_info() - self.assertIsNotNone(raw) + self.assertIsNotNone(raw) \ No newline at end of file diff --git a/tests/test_slicing.py b/tests/test_slicing.py new file mode 100644 index 0000000..c5e763c --- /dev/null +++ b/tests/test_slicing.py @@ -0,0 +1,50 @@ +import pathlib +from sre_constants import SRE_FLAG_UNICODE +import h5py +import unittest +import numpy +import bwpy +import numpy as np + +path = pathlib.Path(__file__).parent + +class TestFileObjects(unittest.TestCase): + def setUp(self): + super().setUp() + self.file = bwpy.File(f"{path}/test_samples/100frames.brw", "r") + + def tearDown(self) -> None: + super().tearDown() + self.file.close() + + def test_file_integrity(self): + self.assertEqual(self.file.data[()].shape, (self.file.n_channels, self.file.n_frames)) + + def test_basic_slicing(self): + self.assertEqual(self.file.raw.size, self.file.data.size) + print("get unsliced slice passed") + self.assertEqual((self.file.n_channels, 1), self.file.t[1].data.shape) + print("get channel slice passed") + self.assertEqual((1, self.file.n_frames), self.file.ch[1, 1].data.shape) + print("get time slice passed") + + def test_concatenate_slicing(self): + self.assertEqual((60 * 60, 90), self.file.t[:90].ch[0:60, 0:60].data.shape) + new_slice = self.file.t[:90].ch[:60, :60] + print("2 slices concat passed") + self.assertEqual((40 * 50, 70), new_slice.t[10:80].ch[:40, 10:].data.shape) + new_slice_2 = new_slice.t[10:80].ch[:40, :50] + print("4 slices concat passed") + self.assertEqual((1 * 50, 1), new_slice_2.t[65].ch[30, :].data.shape) + new_slice_3 = new_slice_2.t[65].ch[30, :] + print("6 slices concat passed") + self.assertEqual((1 * 50, 1), new_slice_3.t[:].ch[:].data.shape) + print("8 slices concat passed") + + def test_slicing_out_of_range_index(self): + self.assertEqual((64 * 2, 100), self.file.t[:150].ch[:850, -2:88880].data.shape) + + def test_empty_slice(self): + empty_slice = bwpy.File(f"{path}/test_samples/empty.brw", "r") + + self.assertEqual((4096, 0), empty_slice.t[:150].ch[:850, :].data.shape) From 0043887eb0076231a76ccfde84deb059ff794ea4 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Tue, 4 Oct 2022 17:41:15 +0200 Subject: [PATCH 03/19] more tests on slicing and docs --- docs/source/index.rst | 40 ++++++++++++++++++++++++++++++++++++++++ tests/__init__.py | 0 tests/test_slicing.py | 9 ++++----- 3 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 tests/__init__.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 0ee0969..93beb63 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,6 +31,46 @@ BWR and BXR files can be opened as a regular :class:`h5py.File` objects (see `Fi with bwpy("my_data.bwr", "r") as datafile: print(datafile.description) +Slicing +------- +The package allows slicing data of files of `.brw` format. +Slices are masks that can be applied to the data. + +Temporal slices can be obtained by calling the property `t`, which is an unidimensional `numpy` array: +.. code-block:: python + + import bwpy + + with bwpy("my_data.bwr", "r") as datafile: + datafile.t[0:10:2] + #will return the slice of the first 10 temporal recordings with a step of 2 + +Channel slices can be obtained by calling the porperty `ch`, which is bi-dimensional `numpy` array: +.. code-block:: python + + import bwpy + + with bwpy("my_data.bwr", "r") as datafile: + datafile.ch[0:10, 0: 10] + #will return the slice of the block of the first 10x10 channels + +Slices can be combined: +.. code-block:: python + + import bwpy + + with bwpy("my_data.bwr", "r") as datafile: + datafile.t[0:10].ch[0, 0] + #will return the slice of the first 10 temporal recordings of the first channel + +When the slicing is completed, it is possible to apply the mask to the data by calling `data` as it follows: +.. code-block:: python + + import bwpy + + with bwpy("my_data.bwr", "r") as datafile: + sliced_data = datafile.t[0:10].ch[0, 0].data + Indices and tables ================== diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_slicing.py b/tests/test_slicing.py index c5e763c..180610b 100644 --- a/tests/test_slicing.py +++ b/tests/test_slicing.py @@ -1,8 +1,5 @@ import pathlib -from sre_constants import SRE_FLAG_UNICODE -import h5py import unittest -import numpy import bwpy import numpy as np @@ -23,10 +20,12 @@ def test_file_integrity(self): def test_basic_slicing(self): self.assertEqual(self.file.raw.size, self.file.data.size) print("get unsliced slice passed") + self.assertEqual(True, np.allclose(self.file.t[0].data.reshape(-1), self.convert(self.file.n_channels))) self.assertEqual((self.file.n_channels, 1), self.file.t[1].data.shape) - print("get channel slice passed") - self.assertEqual((1, self.file.n_frames), self.file.ch[1, 1].data.shape) print("get time slice passed") + self.assertEqual(True, np.allclose(self.file.ch[0, 0].data.reshape(-1), self.convert(self.file.n_channels))) + self.assertEqual((1, self.file.n_frames), self.file.ch[1, 1].data.shape) + print("get channel slice passed") def test_concatenate_slicing(self): self.assertEqual((60 * 60, 90), self.file.t[:90].ch[0:60, 0:60].data.shape) From f7720bc82177c617fd623564f66e84d9b175211b Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Tue, 11 Oct 2022 12:14:39 +0200 Subject: [PATCH 04/19] docs changes --- docs/source/index.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 93beb63..97ffc22 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,8 +33,7 @@ BWR and BXR files can be opened as a regular :class:`h5py.File` objects (see `Fi Slicing ------- -The package allows slicing data of files of `.brw` format. -Slices are masks that can be applied to the data. +The package allows slicing data of files of `.brw` format. Temporal slices can be obtained by calling the property `t`, which is an unidimensional `numpy` array: .. code-block:: python From 67efa269e67331886b54b06261d97d48e192f288 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Tue, 11 Oct 2022 13:03:08 +0200 Subject: [PATCH 05/19] test changes --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3cb791c..a250b59 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 22.10.0 hooks: - id: black language_version: python3 From a15c94b3083c70fd375d9e42aeb92c7598b0de39 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Wed, 12 Oct 2022 16:56:56 +0200 Subject: [PATCH 06/19] black formatting + test slicing eliminated prints --- bwpy/__init__.py | 33 ++++++++++-------- tests/test_file.py | 2 +- tests/test_slicing.py | 80 ++++++++++++++++++++++++++++++------------- 3 files changed, 76 insertions(+), 39 deletions(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 24e85c8..b219b1a 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -102,7 +102,9 @@ def get_raw_user_info(self): return self["3BUserInfo"] def convert(self, dv): - step_v = self.signal_inversion * ((self.max_volt - self.min_volt) / 2 ** self.bit_depth) + step_v = self.signal_inversion * ( + (self.max_volt - self.min_volt) / 2**self.bit_depth + ) v_offset = self.signal_inversion * self.min_volt return dv * step_v + v_offset @@ -123,14 +125,14 @@ def __init__(self, slice): class _TimeSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._time_slice(instruction) - #self._file.n_frames = + # self._file.n_frames = return slice class _ChannelSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._channel_slice(instruction) - #self._file.n_channel = + # self._file.n_channel = return slice @@ -153,7 +155,7 @@ def t(self): @functools.cache def ch(self): return _ChannelSlicer(self) - + @property def channels(self): return self._channels @@ -167,10 +169,12 @@ def data(self): t_step * self._file.n_channels, ) cols = self._file.layout.shape[1] - mask = np.array([(row - 1) * cols + (col - 1) for row, col in self.channels[()].ravel()]) + mask = np.array( + [(row - 1) * cols + (col - 1) for row, col in self.channels[()].ravel()] + ) if len(self._file.raw) < len(mask): - raise KeyError(f"You recorded less than 1 value per channel.") + raise ValueError(f"You recorded less than 1 value per channel.") start, stop, step = time_slice.indices(len(self._file.raw)) time_ind = np.tile(mask, ((stop - start) // step, 1)) @@ -182,10 +186,10 @@ def data(self): end_slice = i + 1000 data[i:end_slice] = self._file.raw[time_ind[i:end_slice]] return self._file.convert(data.reshape((len(mask), -1))) - + def _slice_index(self): - return - + return + def _time_slice(self, instruction): if isinstance(instruction, int): instruction = slice(instruction, instruction + 1, 1) @@ -198,10 +202,9 @@ def _time_slice(self, instruction): step = this_step * prev_step return _Slice(self._file, self._channels, slice(start, stop, step)) - def _channel_slice(self, instruction): return _Slice(self._file, self._channels[instruction], self._time) - + class BRWFile(File, _Slice): def __post_init__(self): @@ -212,7 +215,7 @@ def _get_descr_prefix(self): @property def channels(self): - return self['/3BRecInfo/3BMeaStreams/Raw/Chs'] + return self["/3BRecInfo/3BMeaStreams/Raw/Chs"] @property def n_channels(self): @@ -220,11 +223,11 @@ def n_channels(self): @property def layout(self): - return self['/3BRecInfo/3BMeaChip/Layout'] - + return self["/3BRecInfo/3BMeaChip/Layout"] + @property def raw(self): - return self['/3BData/Raw'] + return self["/3BData/Raw"] class BXRFile(File): diff --git a/tests/test_file.py b/tests/test_file.py index e34a8cd..59fa178 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -116,4 +116,4 @@ class TestUserInfo(unittest.TestCase): def test_raw_user_info(self): with open_sample(samples.bxr, "r") as f: raw = f.get_raw_user_info() - self.assertIsNotNone(raw) \ No newline at end of file + self.assertIsNotNone(raw) diff --git a/tests/test_slicing.py b/tests/test_slicing.py index 180610b..37e2791 100644 --- a/tests/test_slicing.py +++ b/tests/test_slicing.py @@ -5,45 +5,79 @@ path = pathlib.Path(__file__).parent + class TestFileObjects(unittest.TestCase): def setUp(self): super().setUp() self.file = bwpy.File(f"{path}/test_samples/100frames.brw", "r") - + def tearDown(self) -> None: super().tearDown() self.file.close() - + def test_file_integrity(self): - self.assertEqual(self.file.data[()].shape, (self.file.n_channels, self.file.n_frames)) - + self.assertEqual( + self.file.data[()].shape, (self.file.n_channels, self.file.n_frames) + ) + def test_basic_slicing(self): - self.assertEqual(self.file.raw.size, self.file.data.size) - print("get unsliced slice passed") - self.assertEqual(True, np.allclose(self.file.t[0].data.reshape(-1), self.convert(self.file.n_channels))) - self.assertEqual((self.file.n_channels, 1), self.file.t[1].data.shape) - print("get time slice passed") - self.assertEqual(True, np.allclose(self.file.ch[0, 0].data.reshape(-1), self.convert(self.file.n_channels))) - self.assertEqual((1, self.file.n_frames), self.file.ch[1, 1].data.shape) - print("get channel slice passed") + self.assertEqual( + self.file.raw.size, self.file.data.size, "get unsliced slice failed" + ) + self.assertEqual( + True, + np.allclose( + self.file.t[0].data.reshape(-1), + self.file.convert(self.file.raw[: self.file.n_channels]), + ), + "get time slice failed", + ) + self.assertEqual( + (self.file.n_channels, 1), self.file.t[1].data.shape, "get time slice failed" + ) + self.assertEqual( + True, + np.allclose( + self.file.ch[0, 0].data.reshape(-1), + self.file.convert(self.file.raw[:: self.file.n_channels]), + ), + "get channel slice failed", + ) + self.assertEqual( + (1, self.file.n_frames), + self.file.ch[1, 1].data.shape, + "get channel slice failed", + ) def test_concatenate_slicing(self): - self.assertEqual((60 * 60, 90), self.file.t[:90].ch[0:60, 0:60].data.shape) + self.assertEqual( + (60 * 60, 90), + self.file.t[:90].ch[0:60, 0:60].data.shape, + "2 slices concat failed", + ) new_slice = self.file.t[:90].ch[:60, :60] - print("2 slices concat passed") - self.assertEqual((40 * 50, 70), new_slice.t[10:80].ch[:40, 10:].data.shape) + self.assertEqual( + (40 * 50, 70), + new_slice.t[10:80].ch[:40, 10:].data.shape, + "4 slices concat failed", + ) new_slice_2 = new_slice.t[10:80].ch[:40, :50] - print("4 slices concat passed") - self.assertEqual((1 * 50, 1), new_slice_2.t[65].ch[30, :].data.shape) + self.assertEqual( + (1 * 50, 1), new_slice_2.t[65].ch[30, :].data.shape, "6 slices concat failed" + ) new_slice_3 = new_slice_2.t[65].ch[30, :] - print("6 slices concat passed") - self.assertEqual((1 * 50, 1), new_slice_3.t[:].ch[:].data.shape) - print("8 slices concat passed") + self.assertEqual( + (1 * 50, 1), new_slice_3.t[:].ch[:].data.shape, "8 slices concat failed" + ) def test_slicing_out_of_range_index(self): - self.assertEqual((64 * 2, 100), self.file.t[:150].ch[:850, -2:88880].data.shape) + self.assertEqual( + (64 * 2, 100), + self.file.t[:150].ch[:850, -2:88880].data.shape, + "slicing out of range index failed", + ) def test_empty_slice(self): empty_slice = bwpy.File(f"{path}/test_samples/empty.brw", "r") - - self.assertEqual((4096, 0), empty_slice.t[:150].ch[:850, :].data.shape) + with self.assertRaises(ValueError, "It should throw a ValueError") as context: + empty_slice.t[:150].ch[:850, :].data From e238215a039c4b3c2c4d5bbddf39e1e5dfa3e4ff Mon Sep 17 00:00:00 2001 From: Igor10798 <76614313+Igor10798@users.noreply.github.com> Date: Wed, 12 Oct 2022 16:58:57 +0200 Subject: [PATCH 07/19] docs suggestions Co-authored-by: Robin De Schepper --- docs/source/index.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 97ffc22..1bafd04 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,9 +33,8 @@ BWR and BXR files can be opened as a regular :class:`h5py.File` objects (see `Fi Slicing ------- -The package allows slicing data of files of `.brw` format. -Temporal slices can be obtained by calling the property `t`, which is an unidimensional `numpy` array: +The package allows you to slice the data in `.brw` files. The data can be restricted to certain time samples by indexing the `.t` property like a one-dimensional array: .. code-block:: python import bwpy @@ -44,7 +43,8 @@ Temporal slices can be obtained by calling the property `t`, which is an unidime datafile.t[0:10:2] #will return the slice of the first 10 temporal recordings with a step of 2 -Channel slices can be obtained by calling the porperty `ch`, which is bi-dimensional `numpy` array: +The data can be restricted to certain channels by indexing the `.ch` property like a two-dimensional array: + .. code-block:: python import bwpy @@ -53,7 +53,8 @@ Channel slices can be obtained by calling the porperty `ch`, which is bi-dimensi datafile.ch[0:10, 0: 10] #will return the slice of the block of the first 10x10 channels -Slices can be combined: +The obtained slices can themselves be sliced further: + .. code-block:: python import bwpy @@ -62,7 +63,8 @@ Slices can be combined: datafile.t[0:10].ch[0, 0] #will return the slice of the first 10 temporal recordings of the first channel -When the slicing is completed, it is possible to apply the mask to the data by calling `data` as it follows: +After slicing, the sliced data can be obtained by accessing the `data` property: + .. code-block:: python import bwpy From 1dd42b5de24454de750e7cc71136c128340e3aef Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Wed, 12 Oct 2022 17:28:51 +0200 Subject: [PATCH 08/19] added examples in slicing function --- bwpy/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index b219b1a..676f265 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -200,6 +200,9 @@ def _time_slice(self, instruction): start = prev_start + this_start * prev_step stop = prev_start + this_stop * prev_step step = this_step * prev_step + # start: 1, end: 5, step: 1 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [1, 2, 3, 4] + # start: 0, end: 7, step: 2 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [0, 2, 4, 6] + # start: 5, end: -1, step: 4 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [5, 9] return _Slice(self._file, self._channels, slice(start, stop, step)) def _channel_slice(self, instruction): From 01531460b25e5e206db697184f64371b314ed198 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Wed, 19 Oct 2022 16:42:21 +0200 Subject: [PATCH 09/19] Transformation added --- bwpy/signal.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 bwpy/signal.py diff --git a/bwpy/signal.py b/bwpy/signal.py new file mode 100644 index 0000000..30b81e7 --- /dev/null +++ b/bwpy/signal.py @@ -0,0 +1,54 @@ +import re +import abc +import functools +import numpy as np + + +__all__ = [] + + +def _transformer_factory(cls): + @functools.wraps(cls) + def transformer_factory(slice, *args, **kwargs): + return slice._transform(cls(*args, **kwargs)) + + return transformer_factory + + +class Transformer(abc.ABC): + def __init_subclass__(cls, operator=None, **kwargs) -> None: + super().__init_subclass__(**kwargs) + name = operator or re.sub("([a-z0-9])([A-Z])", r"\1_\2", cls.__name__).lower() + globals()[name] = _transformer_factory(cls) + __all__.append(name) + + @abc.abstractmethod + def __call__(self, data, file): + pass + + +class Variation(Transformer): + def __call__(self, data, file): + print("We are transforming our data using the", self.__class__.__name__) + try: + return np.amax(data, axis=2) - np.amin(data, axis=2) + except np.AxisError: + return data + + +class Amplitude(Transformer): + def __call__(self, data, file): + print("We are transforming our data using the", self.__class__.__name__) + try: + return np.amax(data, axis=2) + except np.AxisError: + return data + + +class Energy(Transformer): + def __call__(self, data, file): + print("We are transforming our data using the", self.__class__.__name__) + try: + return np.sum(np.square(data), axis=2) + except np.AxisError: + return data From 573ff0271aff4b94d90e9c3a64d7cb3be662e7c1 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Wed, 19 Oct 2022 16:44:36 +0200 Subject: [PATCH 10/19] transformation added to data --- bwpy/__init__.py | 49 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 676f265..33c78be 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -136,6 +136,19 @@ def __getitem__(self, instruction): return slice +class Variation: + def __init__(self, bin_width): + self.offset = offset + + def __call__(self, data): + return data + self.offset + + +def variation(slice, *args, **kwargs): + slice.transformations.append(Variation(*args, **kwargs)) + return slice + + class _Slice: def __init__(self, file, channels=None, time=None): self._file = file @@ -145,6 +158,7 @@ def __init__(self, file, channels=None, time=None): if time is None: time = slice(None) self._time = time + self._transformations = [] @property @functools.cache @@ -185,10 +199,17 @@ def data(self): for i in range(0, len(time_ind), 1000): end_slice = i + 1000 data[i:end_slice] = self._file.raw[time_ind[i:end_slice]] - return self._file.convert(data.reshape((len(mask), -1))) - def _slice_index(self): - return + data = self._file.convert(data.reshape((len(mask), -1))) + for transformation in self._transformations: + try: + data = transformation(data, self._file) + except Exception as e: + raise TransformationError( + f"Error in transformation pipeline {self._transformations}" + f", with {transformation}: {e}" + ) + return data def _time_slice(self, instruction): if isinstance(instruction, int): @@ -203,10 +224,24 @@ def _time_slice(self, instruction): # start: 1, end: 5, step: 1 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [1, 2, 3, 4] # start: 0, end: 7, step: 2 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [0, 2, 4, 6] # start: 5, end: -1, step: 4 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [5, 9] - return _Slice(self._file, self._channels, slice(start, stop, step)) + ret = self._copy_slice() + ret._time = slice(start, stop, step) + return ret def _channel_slice(self, instruction): - return _Slice(self._file, self._channels[instruction], self._time) + ret = self._copy_slice() + ret._channels = self._channels[instruction] + return ret + + def _transform(self, transformation): + ret = self._copy_slice() + ret._transformations.append(transformation) + return ret + + def _copy_slice(self): + copied = _Slice(self._file, self._channels, self._time) + copied._transformations = self._transformations.copy() + return copied class BRWFile(File, _Slice): @@ -270,4 +305,8 @@ def get_channel_group(self, group_id): return ChannelGroup._from_bxr(self, data) +class TransformationError(Exception): + pass + + __all__ = ["File", "BRWFile", "BXRFile"] From 342a49ae229bb557a36f625f94c50423728fbc08 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Tue, 25 Oct 2022 15:12:20 +0200 Subject: [PATCH 11/19] fixed slicing shaping, mea viewer class, --- bwpy/__init__.py | 6 +++- bwpy/mea_viewer.py | 66 +++++++++++++++++++++++++++++++++++++++++++ bwpy/signal.py | 18 ++++++++---- tests/test_slicing.py | 9 +++--- 4 files changed, 89 insertions(+), 10 deletions(-) create mode 100644 bwpy/mea_viewer.py diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 33c78be..e4849c6 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -200,7 +200,11 @@ def data(self): end_slice = i + 1000 data[i:end_slice] = self._file.raw[time_ind[i:end_slice]] - data = self._file.convert(data.reshape((len(mask), -1))) + digital = data.reshape((len(mask), -1)) + analog = self._file.convert(digital) + data = analog.reshape((-1, *self._file.layout.shape)) + data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1) + for transformation in self._transformations: try: data = transformation(data, self._file) diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py new file mode 100644 index 0000000..6e973fe --- /dev/null +++ b/bwpy/mea_viewer.py @@ -0,0 +1,66 @@ +from .signal import * +import abc +import plotly.graph_objects as go +import numpy as np + + +class Viewer(abc.ABC): + @abc.abstractmethod + def set_up_view(self, slice, bin_size, view_method): + pass + + +class MEAViewer(Viewer): + colorscale = [[0, "#ebe834"], [1.0, "#eb4034"]] + min_value = 0 + max_value = 170 + + def set_up_view(self, slice, view_method, bin_size=100): + sliced_data = slice.data + fig = go.Figure() + for i in range(0, sliced_data.shape[1]): + if sliced_data.shape[1] - i < bin: + break + transform = getattr(signal, view_method) + bin = self.ms_to_idx(slice, bin_size) + signal = transform(slice._file.t[i : i + bin]).data + fig.add_trace( + go.Heatmap( + visible=False, + z=signal, + zmin=self.min_value, + zmax=self.max_value, + colorscale=self.colorscale, + ) + ) + return fig + + def ms_to_idx(self, slice, bin_size): + return slice._file.n_frames * slice._file.sampling_rate, bin_size + + def format_plot(self, fig): + # Create and add slider + fig.data[0].visible = True + steps = [] + for i in range(len(fig.data)): + step = dict( + method="update", + args=[ + {"visible": [False] * len(fig.data)}, + {"title": "Time slice: " + str(i)}, + ], # layout attribute + ) + step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible" + steps.append(step) + + sliders = [ + dict(active=0, currentvalue={"prefix": "Time: "}, pad={"t": 50}, steps=steps) + ] + + fig.update_layout( + yaxis=dict(scaleanchor="x", autorange="reversed"), sliders=sliders + ) + return fig + + def show(self, fig): + fig.show() diff --git a/bwpy/signal.py b/bwpy/signal.py index 30b81e7..774c84f 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -31,7 +31,9 @@ class Variation(Transformer): def __call__(self, data, file): print("We are transforming our data using the", self.__class__.__name__) try: - return np.amax(data, axis=2) - np.amin(data, axis=2) + return np.amax( + data.reshape((file.layout.shape[0], file.layout.shape[1], -1)), axis=2 + ) - np.amin(data, axis=2) except np.AxisError: return data @@ -39,9 +41,12 @@ def __call__(self, data, file): class Amplitude(Transformer): def __call__(self, data, file): print("We are transforming our data using the", self.__class__.__name__) - try: - return np.amax(data, axis=2) - except np.AxisError: + print("input data:", data.shape) + if data.ndim < 3: + return data + else: + data = np.max(np.abs(data), axis=2) + print("out data:", data.shape) return data @@ -49,6 +54,9 @@ class Energy(Transformer): def __call__(self, data, file): print("We are transforming our data using the", self.__class__.__name__) try: - return np.sum(np.square(data), axis=2) + return np.sum( + np.square(data.reshape((file.layout.shape[0], file.layout.shape[1], -1))), + axis=2, + ) except np.AxisError: return data diff --git a/tests/test_slicing.py b/tests/test_slicing.py index 37e2791..81b9305 100644 --- a/tests/test_slicing.py +++ b/tests/test_slicing.py @@ -17,7 +17,7 @@ def tearDown(self) -> None: def test_file_integrity(self): self.assertEqual( - self.file.data[()].shape, (self.file.n_channels, self.file.n_frames) + self.file.data[()].shape, (*self.file.layout, self.file.n_frames) ) def test_basic_slicing(self): @@ -33,7 +33,7 @@ def test_basic_slicing(self): "get time slice failed", ) self.assertEqual( - (self.file.n_channels, 1), self.file.t[1].data.shape, "get time slice failed" + (*self.file.layout, 1), self.file.t[1].data.shape, "get time slice failed" ) self.assertEqual( True, @@ -51,7 +51,7 @@ def test_basic_slicing(self): def test_concatenate_slicing(self): self.assertEqual( - (60 * 60, 90), + (60, 60, 90), self.file.t[:90].ch[0:60, 0:60].data.shape, "2 slices concat failed", ) @@ -72,12 +72,13 @@ def test_concatenate_slicing(self): def test_slicing_out_of_range_index(self): self.assertEqual( - (64 * 2, 100), + (64, 2, 100), self.file.t[:150].ch[:850, -2:88880].data.shape, "slicing out of range index failed", ) def test_empty_slice(self): empty_slice = bwpy.File(f"{path}/test_samples/empty.brw", "r") + empty_slice.t[:150].ch[:850, :].data with self.assertRaises(ValueError, "It should throw a ValueError") as context: empty_slice.t[:150].ch[:850, :].data From b492a12410ab0ef554068c22a247ffad2679f211 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Sat, 12 Nov 2022 16:02:38 +0100 Subject: [PATCH 12/19] artifact detection --- bwpy/__init__.py | 16 +++++----- bwpy/mea_viewer.py | 27 ++++++++--------- bwpy/signal.py | 73 +++++++++++++++++++++++++++++++--------------- 3 files changed, 72 insertions(+), 44 deletions(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index e4849c6..52c2d10 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -125,21 +125,16 @@ def __init__(self, slice): class _TimeSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._time_slice(instruction) - # self._file.n_frames = return slice class _ChannelSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._channel_slice(instruction) - # self._file.n_channel = return slice class Variation: - def __init__(self, bin_width): - self.offset = offset - def __call__(self, data): return data + self.offset @@ -159,6 +154,7 @@ def __init__(self, file, channels=None, time=None): time = slice(None) self._time = time self._transformations = [] + self.bin_size = 100 @property @functools.cache @@ -202,12 +198,17 @@ def data(self): digital = data.reshape((len(mask), -1)) analog = self._file.convert(digital) + # Shape (-1, row, cols) because in the next line we'll swapaxes. + # If we used directly shape (row,cols,-1) we would end up with data[:,:,0]: + # [0,3,6] [0,1,2] + # [1,4,7] instead of [3,4,5] + # [2,5,8] [6,7,8] data = analog.reshape((-1, *self._file.layout.shape)) data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1) for transformation in self._transformations: try: - data = transformation(data, self._file) + data = transformation(data, self._file, self.bin_size) except Exception as e: raise TransformationError( f"Error in transformation pipeline {self._transformations}" @@ -237,9 +238,10 @@ def _channel_slice(self, instruction): ret._channels = self._channels[instruction] return ret - def _transform(self, transformation): + def _transform(self, transformation, bin_size): ret = self._copy_slice() ret._transformations.append(transformation) + self.bin_size = bin_size return ret def _copy_slice(self): diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py index 6e973fe..3c81f38 100644 --- a/bwpy/mea_viewer.py +++ b/bwpy/mea_viewer.py @@ -1,4 +1,4 @@ -from .signal import * +from . import signal import abc import plotly.graph_objects as go import numpy as np @@ -6,28 +6,29 @@ class Viewer(abc.ABC): @abc.abstractmethod - def set_up_view(self, slice, bin_size, view_method): + def build_view(self, slice, bin_size, view_method): pass class MEAViewer(Viewer): colorscale = [[0, "#ebe834"], [1.0, "#eb4034"]] min_value = 0 - max_value = 170 + max_value = 0 - def set_up_view(self, slice, view_method, bin_size=100): - sliced_data = slice.data + def build_view(self, slice, view_method="amplitude", bin_size=100): + self.min_value = slice._file.convert(slice._file.min_volt) + self.max_value = slice._file.convert(slice._file.max_volt) + # min = -12433, max = 4183 + self.min_value = 0 + self.max_value = 170 fig = go.Figure() - for i in range(0, sliced_data.shape[1]): - if sliced_data.shape[1] - i < bin: - break - transform = getattr(signal, view_method) - bin = self.ms_to_idx(slice, bin_size) - signal = transform(slice._file.t[i : i + bin]).data + apply_transformation = getattr(signal, view_method) + signals = apply_transformation(slice, bin_size).data + for signal_frame in signals: fig.add_trace( go.Heatmap( visible=False, - z=signal, + z=signal_frame, zmin=self.min_value, zmax=self.max_value, colorscale=self.colorscale, @@ -36,7 +37,7 @@ def set_up_view(self, slice, view_method, bin_size=100): return fig def ms_to_idx(self, slice, bin_size): - return slice._file.n_frames * slice._file.sampling_rate, bin_size + return slice._file.n_frames / slice._file.sampling_rate, bin_size def format_plot(self, fig): # Create and add slider diff --git a/bwpy/signal.py b/bwpy/signal.py index 774c84f..f97f8f8 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -2,6 +2,8 @@ import abc import functools import numpy as np +import numpy.lib.stride_tricks as np_tricks +from scipy.signal import find_peaks __all__ = [] @@ -9,8 +11,8 @@ def _transformer_factory(cls): @functools.wraps(cls) - def transformer_factory(slice, *args, **kwargs): - return slice._transform(cls(*args, **kwargs)) + def transformer_factory(slice, bin_size, *args, **kwargs): + return slice._transform(cls(*args, **kwargs), bin_size) return transformer_factory @@ -22,41 +24,64 @@ def __init_subclass__(cls, operator=None, **kwargs) -> None: globals()[name] = _transformer_factory(cls) __all__.append(name) + def _apply_window(self, data, bin_size): + rows = data.shape[0] + cols = data.shape[1] + window_shape = (rows, cols, bin_size) + # sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, bin_size) + # with the reshape we get rid of the unnecessary complexity (1, 1) + return np_tricks.sliding_window_view(data, window_shape).reshape( + -1, rows, cols, bin_size + ) + @abc.abstractmethod - def __call__(self, data, file): + def __call__(self, data, file, bin_size): pass class Variation(Transformer): - def __call__(self, data, file): - print("We are transforming our data using the", self.__class__.__name__) - try: - return np.amax( - data.reshape((file.layout.shape[0], file.layout.shape[1], -1)), axis=2 - ) - np.amin(data, axis=2) - except np.AxisError: + def __call__(self, data, file, bin_size): + print("calling ", self.__class__.__name__) + if data.ndim < 3: return data + else: + windows = self._apply_window(data, bin_size) + return np.max(np.abs(windows), axis=3) - np.min(np.abs(windows), axis=3) class Amplitude(Transformer): - def __call__(self, data, file): - print("We are transforming our data using the", self.__class__.__name__) - print("input data:", data.shape) + def __call__(self, data, file, bin_size): + print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: - data = np.max(np.abs(data), axis=2) - print("out data:", data.shape) - return data + windows = self._apply_window(data, bin_size) + return np.max(np.abs(windows), axis=3) class Energy(Transformer): - def __call__(self, data, file): - print("We are transforming our data using the", self.__class__.__name__) - try: - return np.sum( - np.square(data.reshape((file.layout.shape[0], file.layout.shape[1], -1))), - axis=2, - ) - except np.AxisError: + def __call__(self, data, file, bin_size): + print("calling ", self.__class__.__name__) + if data.ndim < 3: return data + else: + windows = self._apply_window(data, bin_size) + return np.sum(np.square(windows), axis=3) + + +class Raw(Transformer): + def __call__(self, data, file, bin_size): + print("calling ", self.__class__.__name__) + return np.moveaxis(data, 2, 0) + + +class DetectArtifacts(Transformer): + def __call__(self, data, file, bin_size): + n_channels = file.layout.shape[0] * file.layout.shape[1] + artifacts = [] + out_bounds = data > 170 + + for i in range(len(data)): + if np.count_nonzero(out_bounds[i]) / n_channels > 0.8: + artifacts.append(i) + return artifacts From ae3e7709582838d93a1fbebd1766c5d0ef12d7ab Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Wed, 16 Nov 2022 09:46:15 +0100 Subject: [PATCH 13/19] created shutter, fixed architecture --- bwpy/__init__.py | 5 ++-- bwpy/mea_viewer.py | 16 ++++++----- bwpy/signal.py | 69 ++++++++++++++++++++++++++++++++-------------- 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 52c2d10..5a810f8 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -208,7 +208,7 @@ def data(self): for transformation in self._transformations: try: - data = transformation(data, self._file, self.bin_size) + data = transformation(data, self, self._file) except Exception as e: raise TransformationError( f"Error in transformation pipeline {self._transformations}" @@ -238,10 +238,9 @@ def _channel_slice(self, instruction): ret._channels = self._channels[instruction] return ret - def _transform(self, transformation, bin_size): + def _transform(self, transformation): ret = self._copy_slice() ret._transformations.append(transformation) - self.bin_size = bin_size return ret def _copy_slice(self): diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py index 3c81f38..97a0ec9 100644 --- a/bwpy/mea_viewer.py +++ b/bwpy/mea_viewer.py @@ -2,6 +2,7 @@ import abc import plotly.graph_objects as go import numpy as np +from scipy.stats import norm class Viewer(abc.ABC): @@ -16,11 +17,6 @@ class MEAViewer(Viewer): max_value = 0 def build_view(self, slice, view_method="amplitude", bin_size=100): - self.min_value = slice._file.convert(slice._file.min_volt) - self.max_value = slice._file.convert(slice._file.max_volt) - # min = -12433, max = 4183 - self.min_value = 0 - self.max_value = 170 fig = go.Figure() apply_transformation = getattr(signal, view_method) signals = apply_transformation(slice, bin_size).data @@ -30,15 +26,21 @@ def build_view(self, slice, view_method="amplitude", bin_size=100): visible=False, z=signal_frame, zmin=self.min_value, - zmax=self.max_value, + zmax=self.get_up_bound(signals, slice._file), colorscale=self.colorscale, ) ) - return fig + return self.format_plot(fig) def ms_to_idx(self, slice, bin_size): return slice._file.n_frames / slice._file.sampling_rate, bin_size + def get_up_bound(self, data, file): + up_limit = file.convert(file.max_volt) * 0.98 + no_artifacts = data[data < up_limit] + mu, sd = norm.fit(no_artifacts.reshape(-1)) + return mu + 2 * sd + def format_plot(self, fig): # Create and add slider fig.data[0].visible = True diff --git a/bwpy/signal.py b/bwpy/signal.py index f97f8f8..bb35435 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -3,7 +3,6 @@ import functools import numpy as np import numpy.lib.stride_tricks as np_tricks -from scipy.signal import find_peaks __all__ = [] @@ -11,8 +10,8 @@ def _transformer_factory(cls): @functools.wraps(cls) - def transformer_factory(slice, bin_size, *args, **kwargs): - return slice._transform(cls(*args, **kwargs), bin_size) + def transformer_factory(slice, *args, **kwargs): + return slice._transform(cls(*args, **kwargs)) return transformer_factory @@ -35,53 +34,83 @@ def _apply_window(self, data, bin_size): ) @abc.abstractmethod - def __call__(self, data, file, bin_size): + def __call__(self, data, file): pass class Variation(Transformer): - def __call__(self, data, file, bin_size): + def __init__(self, bin_size): + self.bin_size = bin_size + + def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: - windows = self._apply_window(data, bin_size) + windows = self._apply_window(data, self.bin_size) return np.max(np.abs(windows), axis=3) - np.min(np.abs(windows), axis=3) class Amplitude(Transformer): - def __call__(self, data, file, bin_size): + def __init__(self, bin_size): + self.bin_size = bin_size + + def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: - windows = self._apply_window(data, bin_size) + windows = self._apply_window(data, self.bin_size) return np.max(np.abs(windows), axis=3) class Energy(Transformer): - def __call__(self, data, file, bin_size): + def __init__(self, bin_size): + self.bin_size = bin_size + + def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: - windows = self._apply_window(data, bin_size) + windows = self._apply_window(data, self.bin_size) return np.sum(np.square(windows), axis=3) class Raw(Transformer): - def __call__(self, data, file, bin_size): + def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) return np.moveaxis(data, 2, 0) class DetectArtifacts(Transformer): - def __call__(self, data, file, bin_size): - n_channels = file.layout.shape[0] * file.layout.shape[1] - artifacts = [] - out_bounds = data > 170 - - for i in range(len(data)): - if np.count_nonzero(out_bounds[i]) / n_channels > 0.8: - artifacts.append(i) - return artifacts + def __call__(self, data, slice, file): + print("calling ", self.__class__.__name__) + if data.ndim < 3: + return data + else: + up_limit = file.convert(file.max_volt) * 0.98 + out_bounds = data > up_limit + mask = np.sum(out_bounds, axis=(1, 2)) > 80 + return mask + + +class Shutter(Transformer): + def __init__(self, data, delay_ms): + self.delay_ms = delay_ms + self.data = data + + def __call__(self, mask, slice, file): + print("calling ", self.__class__.__name__) + if mask.ndim > 1: + raise ValueError("mask must be a 1dim array.") + + delay = self.ms_to_idx(file, self.delay_ms) + for i in range(len(mask) - 1, 0, -1): + if mask[i]: + mask[i : i + delay] = 1 + + return self.data[mask] + + def ms_to_idx(self, file, delay_ms): + return int(delay_ms * file.sampling_rate * 1000) From 11633c4c5d7ece27df449b9ea6f15bfbc0b03f82 Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Thu, 17 Nov 2022 15:51:19 +0100 Subject: [PATCH 14/19] changed slice.data structure from [rows, cols, frames] to [frames, rows, cols] --- bwpy/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 5a810f8..f66c950 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -204,7 +204,8 @@ def data(self): # [1,4,7] instead of [3,4,5] # [2,5,8] [6,7,8] data = analog.reshape((-1, *self._file.layout.shape)) - data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1) + # in the line above we have [frame, rows, cols], with the line below we get [rows, cols, frame] + # data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1) for transformation in self._transformations: try: From c6ae3322ae670af15393a13934fd0f896951f8ae Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Thu, 17 Nov 2022 15:51:53 +0100 Subject: [PATCH 15/19] meaviewer now accepts slice or np.arrays to display --- bwpy/mea_viewer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py index 97a0ec9..0af5f7d 100644 --- a/bwpy/mea_viewer.py +++ b/bwpy/mea_viewer.py @@ -16,17 +16,25 @@ class MEAViewer(Viewer): min_value = 0 max_value = 0 - def build_view(self, slice, view_method="amplitude", bin_size=100): + def build_view( + self, file, slice=None, view_method="amplitude", bin_size=100, data=None + ): fig = go.Figure() - apply_transformation = getattr(signal, view_method) - signals = apply_transformation(slice, bin_size).data + if slice and data: + raise ValueError("slice and data arguments are mutually exclusives.") + if slice: + apply_transformation = getattr(signal, view_method) + signals = apply_transformation(slice, bin_size).data + else: + signals = data + for signal_frame in signals: fig.add_trace( go.Heatmap( visible=False, z=signal_frame, zmin=self.min_value, - zmax=self.get_up_bound(signals, slice._file), + zmax=self.get_up_bound(signals, file), colorscale=self.colorscale, ) ) From 2bc15db3977eb66e72b3cca1f6b9dd632886a3ae Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Thu, 17 Nov 2022 15:52:33 +0100 Subject: [PATCH 16/19] shutter now automatically displays artifacts. added callback to shutter --- bwpy/signal.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/bwpy/signal.py b/bwpy/signal.py index bb35435..166faf8 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -1,3 +1,4 @@ +from . import mea_viewer import re import abc import functools @@ -24,13 +25,13 @@ def __init_subclass__(cls, operator=None, **kwargs) -> None: __all__.append(name) def _apply_window(self, data, bin_size): - rows = data.shape[0] - cols = data.shape[1] - window_shape = (rows, cols, bin_size) + rows = data.shape[1] + cols = data.shape[2] + window_shape = (bin_size, rows, cols) # sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, bin_size) # with the reshape we get rid of the unnecessary complexity (1, 1) return np_tricks.sliding_window_view(data, window_shape).reshape( - -1, rows, cols, bin_size + -1, bin_size, rows, cols ) @abc.abstractmethod @@ -48,7 +49,7 @@ def __call__(self, data, slice, file): return data else: windows = self._apply_window(data, self.bin_size) - return np.max(np.abs(windows), axis=3) - np.min(np.abs(windows), axis=3) + return np.max(np.abs(windows), axis=1) - np.min(np.abs(windows), axis=1) class Amplitude(Transformer): @@ -61,7 +62,7 @@ def __call__(self, data, slice, file): return data else: windows = self._apply_window(data, self.bin_size) - return np.max(np.abs(windows), axis=3) + return np.max(np.abs(windows), axis=1) class Energy(Transformer): @@ -74,7 +75,7 @@ def __call__(self, data, slice, file): return data else: windows = self._apply_window(data, self.bin_size) - return np.sum(np.square(windows), axis=3) + return np.sum(np.square(windows), axis=1) class Raw(Transformer): @@ -83,6 +84,12 @@ def __call__(self, data, slice, file): return np.moveaxis(data, 2, 0) +class NoMethod(Transformer): + def __call__(self, data, slice, file): + print("calling ", self.__class__.__name__) + return data + + class DetectArtifacts(Transformer): def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) @@ -96,9 +103,10 @@ def __call__(self, data, slice, file): class Shutter(Transformer): - def __init__(self, data, delay_ms): + def __init__(self, data, delay_ms, callable=None): self.delay_ms = delay_ms self.data = data + self.callable = callable def __call__(self, mask, slice, file): print("calling ", self.__class__.__name__) @@ -109,8 +117,17 @@ def __call__(self, mask, slice, file): for i in range(len(mask) - 1, 0, -1): if mask[i]: mask[i : i + delay] = 1 + masked_data = self.data[mask] + + bin_size = self.ms_to_idx(file, 100) + mea_viewer.MEAViewer().build_view( + file, data=masked_data, view_method="no_method", bin_size=bin_size + ).show() - return self.data[mask] + if self.callable: + return self.callable(self.data, mask) + else: + return masked_data def ms_to_idx(self, file, delay_ms): return int(delay_ms * file.sampling_rate * 1000) From e62eda76fc836ab1ed7863e90b678af4f58da9c6 Mon Sep 17 00:00:00 2001 From: Igor10798 <76614313+Igor10798@users.noreply.github.com> Date: Fri, 18 Nov 2022 10:19:28 +0100 Subject: [PATCH 17/19] Update bwpy/__init__.py Co-authored-by: Robin De Schepper --- bwpy/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bwpy/__init__.py b/bwpy/__init__.py index f66c950..c33f811 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -31,11 +31,12 @@ def _establish_type(self): if self.description.startswith("BRW"): self.__class__ = BRWFile self._type = "brw" - self.__post_init__() - if self.description.startswith("BXR"): + elif self.description.startswith("BXR"): self.__class__ = BXRFile self._type = "bxr" - self.__post_init__() + else: + raise IOError("File is not in BXR/BRW format") + self.__post_init__() @property def description(self): From 14157411b2320dc685f198a6f1a8aa893a99aade Mon Sep 17 00:00:00 2001 From: Igor10798 <76614313+Igor10798@users.noreply.github.com> Date: Fri, 18 Nov 2022 10:21:11 +0100 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: Robin De Schepper --- bwpy/signal.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/bwpy/signal.py b/bwpy/signal.py index 166faf8..ce71e49 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -57,7 +57,6 @@ def __init__(self, bin_size): self.bin_size = bin_size def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: @@ -70,7 +69,6 @@ def __init__(self, bin_size): self.bin_size = bin_size def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: @@ -80,19 +78,17 @@ def __call__(self, data, slice, file): class Raw(Transformer): def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) return np.moveaxis(data, 2, 0) -class NoMethod(Transformer): +class Noop(Transformer): + """Noop that doesn't transform the data at all.""" def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) return data class DetectArtifacts(Transformer): def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) if data.ndim < 3: return data else: @@ -109,7 +105,6 @@ def __init__(self, data, delay_ms, callable=None): self.callable = callable def __call__(self, mask, slice, file): - print("calling ", self.__class__.__name__) if mask.ndim > 1: raise ValueError("mask must be a 1dim array.") From 030ce99da50036906234be05ceec04e04649abee Mon Sep 17 00:00:00 2001 From: Igor10798 Date: Fri, 18 Nov 2022 11:21:33 +0100 Subject: [PATCH 19/19] deleted raw (no necessary because slice now have the same structure of signals) added WindowedTransformer --- bwpy/mea_viewer.py | 25 +++++++++--------- bwpy/signal.py | 66 +++++++++++++++++++++++----------------------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py index 0af5f7d..b246cf1 100644 --- a/bwpy/mea_viewer.py +++ b/bwpy/mea_viewer.py @@ -1,48 +1,49 @@ from . import signal import abc -import plotly.graph_objects as go -import numpy as np from scipy.stats import norm class Viewer(abc.ABC): @abc.abstractmethod - def build_view(self, slice, bin_size, view_method): + def build_view(self, slice, wiew_method, window_size): pass class MEAViewer(Viewer): colorscale = [[0, "#ebe834"], [1.0, "#eb4034"]] - min_value = 0 - max_value = 0 def build_view( - self, file, slice=None, view_method="amplitude", bin_size=100, data=None + self, file, slice=None, view_method="amplitude", window_size=100, data=None ): + try: + import plotly.graph_objects as go + except: + raise ModuleNotFoundError( + "You have to install plotly in order to use MEAViewer." + ) + fig = go.Figure() if slice and data: raise ValueError("slice and data arguments are mutually exclusives.") if slice: apply_transformation = getattr(signal, view_method) - signals = apply_transformation(slice, bin_size).data + signals = apply_transformation(slice, window_size).data else: signals = data + max_val = self.get_up_bound(signals, file) for signal_frame in signals: fig.add_trace( go.Heatmap( visible=False, z=signal_frame, - zmin=self.min_value, - zmax=self.get_up_bound(signals, file), + zmin=0, + zmax=max_val, colorscale=self.colorscale, ) ) return self.format_plot(fig) - def ms_to_idx(self, slice, bin_size): - return slice._file.n_frames / slice._file.sampling_rate, bin_size - def get_up_bound(self, data, file): up_limit = file.convert(file.max_volt) * 0.98 no_artifacts = data[data < up_limit] diff --git a/bwpy/signal.py b/bwpy/signal.py index 166faf8..24d7e86 100644 --- a/bwpy/signal.py +++ b/bwpy/signal.py @@ -24,76 +24,76 @@ def __init_subclass__(cls, operator=None, **kwargs) -> None: globals()[name] = _transformer_factory(cls) __all__.append(name) - def _apply_window(self, data, bin_size): + @abc.abstractmethod + def __call__(self, data, file): + pass + + +class WindowedTransformer(Transformer): + def get_signal_window(self, data, window_size): rows = data.shape[1] cols = data.shape[2] - window_shape = (bin_size, rows, cols) - # sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, bin_size) + window_shape = (window_size, rows, cols) + # sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, window_size) # with the reshape we get rid of the unnecessary complexity (1, 1) return np_tricks.sliding_window_view(data, window_shape).reshape( - -1, bin_size, rows, cols + -1, window_size, rows, cols ) - @abc.abstractmethod - def __call__(self, data, file): - pass - -class Variation(Transformer): - def __init__(self, bin_size): - self.bin_size = bin_size +class Variation(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) - if data.ndim < 3: + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: return data else: - windows = self._apply_window(data, self.bin_size) + windows = self.get_signal_window(data, self.window_size) return np.max(np.abs(windows), axis=1) - np.min(np.abs(windows), axis=1) -class Amplitude(Transformer): - def __init__(self, bin_size): - self.bin_size = bin_size +class Amplitude(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) - if data.ndim < 3: + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: return data else: - windows = self._apply_window(data, self.bin_size) + windows = self.get_signal_window(data, self.window_size) return np.max(np.abs(windows), axis=1) -class Energy(Transformer): - def __init__(self, bin_size): - self.bin_size = bin_size +class Energy(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) - if data.ndim < 3: + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: return data else: - windows = self._apply_window(data, self.bin_size) + windows = self.get_signal_window(data, self.window_size) return np.sum(np.square(windows), axis=1) -class Raw(Transformer): - def __call__(self, data, slice, file): - print("calling ", self.__class__.__name__) - return np.moveaxis(data, 2, 0) - - class NoMethod(Transformer): def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) return data -class DetectArtifacts(Transformer): +class DetectArtifacts(WindowedTransformer): def __call__(self, data, slice, file): print("calling ", self.__class__.__name__) - if data.ndim < 3: + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: return data else: up_limit = file.convert(file.max_volt) * 0.98 @@ -119,9 +119,9 @@ def __call__(self, mask, slice, file): mask[i : i + delay] = 1 masked_data = self.data[mask] - bin_size = self.ms_to_idx(file, 100) + window_size = self.ms_to_idx(file, 100) mea_viewer.MEAViewer().build_view( - file, data=masked_data, view_method="no_method", bin_size=bin_size + file, data=masked_data, view_method="no_method", window_size=window_size ).show() if self.callable: