diff --git a/bwpy/__init__.py b/bwpy/__init__.py index 676f265..7d88ea0 100644 --- a/bwpy/__init__.py +++ b/bwpy/__init__.py @@ -31,11 +31,12 @@ def _establish_type(self): if self.description.startswith("BRW"): self.__class__ = BRWFile self._type = "brw" - self.__post_init__() - if self.description.startswith("BXR"): + elif self.description.startswith("BXR"): self.__class__ = BXRFile self._type = "bxr" - self.__post_init__() + else: + raise IOError("File is not in BXR/BRW format") + self.__post_init__() @property def description(self): @@ -54,6 +55,7 @@ def description(self, value): value = f"{prefix} - " + value utf8_type = h5py.string_dtype("utf-8", len(value)) value = np.array(value.encode("utf-8"), dtype=utf8_type) + value = np.array(value.encode("utf-8"), dtype=utf8_type) self.attrs["Description"] = value @property @@ -125,17 +127,25 @@ def __init__(self, slice): class _TimeSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._time_slice(instruction) - # self._file.n_frames = return slice class _ChannelSlicer(_Slicer): def __getitem__(self, instruction): slice = self._slice._channel_slice(instruction) - # self._file.n_channel = return slice +class Variation: + def __call__(self, data): + return data + self.offset + + +def variation(slice, *args, **kwargs): + slice.transformations.append(Variation(*args, **kwargs)) + return slice + + class _Slice: def __init__(self, file, channels=None, time=None): self._file = file @@ -145,6 +155,8 @@ def __init__(self, file, channels=None, time=None): if time is None: time = slice(None) self._time = time + self._transformations = [] + self.bin_size = 100 @property @functools.cache @@ -185,10 +197,27 @@ def data(self): for i in range(0, len(time_ind), 1000): end_slice = i + 1000 data[i:end_slice] = self._file.raw[time_ind[i:end_slice]] - return self._file.convert(data.reshape((len(mask), -1))) - def _slice_index(self): - return + digital = data.reshape((len(mask), -1)) + analog = self._file.convert(digital) + # Shape (-1, row, cols) because in the next line we'll swapaxes. + # If we used directly shape (row,cols,-1) we would end up with data[:,:,0]: + # [0,3,6] [0,1,2] + # [1,4,7] instead of [3,4,5] + # [2,5,8] [6,7,8] + data = analog.reshape((-1, *self._file.layout.shape)) + # in the line above we have [frame, rows, cols], with the line below we get [rows, cols, frame] + # data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1) + + for transformation in self._transformations: + try: + data = transformation(data, self, self._file) + except Exception as e: + raise TransformationError( + f"Error in transformation pipeline {self._transformations}" + f", with {transformation}: {e}" + ) + return data def _time_slice(self, instruction): if isinstance(instruction, int): @@ -203,10 +232,24 @@ def _time_slice(self, instruction): # start: 1, end: 5, step: 1 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [1, 2, 3, 4] # start: 0, end: 7, step: 2 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [0, 2, 4, 6] # start: 5, end: -1, step: 4 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [5, 9] - return _Slice(self._file, self._channels, slice(start, stop, step)) + ret = self._copy_slice() + ret._time = slice(start, stop, step) + return ret def _channel_slice(self, instruction): - return _Slice(self._file, self._channels[instruction], self._time) + ret = self._copy_slice() + ret._channels = self._channels[instruction] + return ret + + def _transform(self, transformation): + ret = self._copy_slice() + ret._transformations.append(transformation) + return ret + + def _copy_slice(self): + copied = _Slice(self._file, self._channels, self._time) + copied._transformations = self._transformations.copy() + return copied class BRWFile(File, _Slice): @@ -270,4 +313,8 @@ def get_channel_group(self, group_id): return ChannelGroup._from_bxr(self, data) +class TransformationError(Exception): + pass + + __all__ = ["File", "BRWFile", "BXRFile"] diff --git a/bwpy/mea_viewer.py b/bwpy/mea_viewer.py new file mode 100644 index 0000000..b246cf1 --- /dev/null +++ b/bwpy/mea_viewer.py @@ -0,0 +1,78 @@ +from . import signal +import abc +from scipy.stats import norm + + +class Viewer(abc.ABC): + @abc.abstractmethod + def build_view(self, slice, wiew_method, window_size): + pass + + +class MEAViewer(Viewer): + colorscale = [[0, "#ebe834"], [1.0, "#eb4034"]] + + def build_view( + self, file, slice=None, view_method="amplitude", window_size=100, data=None + ): + try: + import plotly.graph_objects as go + except: + raise ModuleNotFoundError( + "You have to install plotly in order to use MEAViewer." + ) + + fig = go.Figure() + if slice and data: + raise ValueError("slice and data arguments are mutually exclusives.") + if slice: + apply_transformation = getattr(signal, view_method) + signals = apply_transformation(slice, window_size).data + else: + signals = data + + max_val = self.get_up_bound(signals, file) + for signal_frame in signals: + fig.add_trace( + go.Heatmap( + visible=False, + z=signal_frame, + zmin=0, + zmax=max_val, + colorscale=self.colorscale, + ) + ) + return self.format_plot(fig) + + def get_up_bound(self, data, file): + up_limit = file.convert(file.max_volt) * 0.98 + no_artifacts = data[data < up_limit] + mu, sd = norm.fit(no_artifacts.reshape(-1)) + return mu + 2 * sd + + def format_plot(self, fig): + # Create and add slider + fig.data[0].visible = True + steps = [] + for i in range(len(fig.data)): + step = dict( + method="update", + args=[ + {"visible": [False] * len(fig.data)}, + {"title": "Time slice: " + str(i)}, + ], # layout attribute + ) + step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible" + steps.append(step) + + sliders = [ + dict(active=0, currentvalue={"prefix": "Time: "}, pad={"t": 50}, steps=steps) + ] + + fig.update_layout( + yaxis=dict(scaleanchor="x", autorange="reversed"), sliders=sliders + ) + return fig + + def show(self, fig): + fig.show() diff --git a/bwpy/signal.py b/bwpy/signal.py new file mode 100644 index 0000000..b94b974 --- /dev/null +++ b/bwpy/signal.py @@ -0,0 +1,133 @@ +from . import mea_viewer +import re +import abc +import functools +import numpy as np +import numpy.lib.stride_tricks as np_tricks + + +__all__ = [] + + +def _transformer_factory(cls): + @functools.wraps(cls) + def transformer_factory(slice, *args, **kwargs): + return slice._transform(cls(*args, **kwargs)) + + return transformer_factory + + +class Transformer(abc.ABC): + def __init_subclass__(cls, operator=None, **kwargs) -> None: + super().__init_subclass__(**kwargs) + name = operator or re.sub("([a-z0-9])([A-Z])", r"\1_\2", cls.__name__).lower() + globals()[name] = _transformer_factory(cls) + __all__.append(name) + + @abc.abstractmethod + def __call__(self, data, file): + pass + + +class WindowedTransformer(Transformer): + def get_signal_window(self, data, window_size): + rows = data.shape[1] + cols = data.shape[2] + window_shape = (window_size, rows, cols) + # sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, window_size) + # with the reshape we get rid of the unnecessary complexity (1, 1) + return np_tricks.sliding_window_view(data, window_shape).reshape( + -1, window_size, rows, cols + ) + + +class Variation(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size + + def __call__(self, data, slice, file): + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: + return data + else: + windows = self.get_signal_window(data, self.window_size) + return np.max(np.abs(windows), axis=1) - np.min(np.abs(windows), axis=1) + + +class Amplitude(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size + + def __call__(self, data, slice, file): + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: + return data + else: + windows = self.get_signal_window(data, self.window_size) + return np.max(np.abs(windows), axis=1) + + +class Energy(WindowedTransformer): + def __init__(self, window_size): + self.window_size = window_size + + def __call__(self, data, slice, file): + if data.ndim == 2: + return data + else: + windows = self.get_signal_window(data, self.window_size) + return np.sum(np.square(windows), axis=1) + + +class NoMethod(Transformer): + def __call__(self, data, slice, file): + return np.moveaxis(data, 2, 0) + + +class Noop(Transformer): + """Noop that doesn't transform the data at all.""" + + def __call__(self, data, slice, file): + return data + + +class DetectArtifacts(WindowedTransformer): + def __call__(self, data, slice, file): + # If data have only 2 dimensions windowing is not necessary + if data.ndim == 2: + return data + else: + up_limit = file.convert(file.max_volt) * 0.98 + out_bounds = data > up_limit + mask = np.sum(out_bounds, axis=(1, 2)) > 80 + return mask + + +class Shutter(Transformer): + def __init__(self, data, delay_ms, callable=None): + self.delay_ms = delay_ms + self.data = data + self.callable = callable + + def __call__(self, mask, slice, file): + if mask.ndim > 1: + raise ValueError("mask must be a 1dim array.") + + delay = self.ms_to_idx(file, self.delay_ms) + for i in range(len(mask) - 1, 0, -1): + if mask[i]: + mask[i : i + delay] = 1 + masked_data = self.data[mask] + + window_size = self.ms_to_idx(file, 100) + mea_viewer.MEAViewer().build_view( + file, data=masked_data, view_method="no_method", window_size=window_size + ).show() + + if self.callable: + return self.callable(self.data, mask) + else: + return masked_data + + def ms_to_idx(self, file, delay_ms): + return int(delay_ms * file.sampling_rate * 1000) diff --git a/tests/test_slicing.py b/tests/test_slicing.py index 37e2791..81b9305 100644 --- a/tests/test_slicing.py +++ b/tests/test_slicing.py @@ -17,7 +17,7 @@ def tearDown(self) -> None: def test_file_integrity(self): self.assertEqual( - self.file.data[()].shape, (self.file.n_channels, self.file.n_frames) + self.file.data[()].shape, (*self.file.layout, self.file.n_frames) ) def test_basic_slicing(self): @@ -33,7 +33,7 @@ def test_basic_slicing(self): "get time slice failed", ) self.assertEqual( - (self.file.n_channels, 1), self.file.t[1].data.shape, "get time slice failed" + (*self.file.layout, 1), self.file.t[1].data.shape, "get time slice failed" ) self.assertEqual( True, @@ -51,7 +51,7 @@ def test_basic_slicing(self): def test_concatenate_slicing(self): self.assertEqual( - (60 * 60, 90), + (60, 60, 90), self.file.t[:90].ch[0:60, 0:60].data.shape, "2 slices concat failed", ) @@ -72,12 +72,13 @@ def test_concatenate_slicing(self): def test_slicing_out_of_range_index(self): self.assertEqual( - (64 * 2, 100), + (64, 2, 100), self.file.t[:150].ch[:850, -2:88880].data.shape, "slicing out of range index failed", ) def test_empty_slice(self): empty_slice = bwpy.File(f"{path}/test_samples/empty.brw", "r") + empty_slice.t[:150].ch[:850, :].data with self.assertRaises(ValueError, "It should throw a ValueError") as context: empty_slice.t[:150].ch[:850, :].data