Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions bwpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def _establish_type(self):
if self.description.startswith("BRW"):
self.__class__ = BRWFile
self._type = "brw"
self.__post_init__()
if self.description.startswith("BXR"):
elif self.description.startswith("BXR"):
self.__class__ = BXRFile
self._type = "bxr"
self.__post_init__()
else:
raise IOError("File is not in BXR/BRW format")
self.__post_init__()

@property
def description(self):
Expand All @@ -54,6 +55,7 @@ def description(self, value):
value = f"{prefix} - " + value
utf8_type = h5py.string_dtype("utf-8", len(value))
value = np.array(value.encode("utf-8"), dtype=utf8_type)
value = np.array(value.encode("utf-8"), dtype=utf8_type)
self.attrs["Description"] = value

@property
Expand Down Expand Up @@ -125,17 +127,25 @@ def __init__(self, slice):
class _TimeSlicer(_Slicer):
def __getitem__(self, instruction):
slice = self._slice._time_slice(instruction)
# self._file.n_frames =
return slice


class _ChannelSlicer(_Slicer):
def __getitem__(self, instruction):
slice = self._slice._channel_slice(instruction)
# self._file.n_channel =
return slice


class Variation:
def __call__(self, data):
return data + self.offset


def variation(slice, *args, **kwargs):
slice.transformations.append(Variation(*args, **kwargs))
return slice


class _Slice:
def __init__(self, file, channels=None, time=None):
self._file = file
Expand All @@ -145,6 +155,8 @@ def __init__(self, file, channels=None, time=None):
if time is None:
time = slice(None)
self._time = time
self._transformations = []
self.bin_size = 100

@property
@functools.cache
Expand Down Expand Up @@ -185,10 +197,27 @@ def data(self):
for i in range(0, len(time_ind), 1000):
end_slice = i + 1000
data[i:end_slice] = self._file.raw[time_ind[i:end_slice]]
return self._file.convert(data.reshape((len(mask), -1)))

def _slice_index(self):
return
digital = data.reshape((len(mask), -1))
analog = self._file.convert(digital)
# Shape (-1, row, cols) because in the next line we'll swapaxes.
# If we used directly shape (row,cols,-1) we would end up with data[:,:,0]:
# [0,3,6] [0,1,2]
# [1,4,7] instead of [3,4,5]
# [2,5,8] [6,7,8]
data = analog.reshape((-1, *self._file.layout.shape))
# in the line above we have [frame, rows, cols], with the line below we get [rows, cols, frame]
# data = np.flip(np.rot90(data.swapaxes(2, 0), -1), 1)

for transformation in self._transformations:
try:
data = transformation(data, self, self._file)
except Exception as e:
raise TransformationError(
f"Error in transformation pipeline {self._transformations}"
f", with {transformation}: {e}"
)
return data

def _time_slice(self, instruction):
if isinstance(instruction, int):
Expand All @@ -203,10 +232,24 @@ def _time_slice(self, instruction):
# start: 1, end: 5, step: 1 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [1, 2, 3, 4]
# start: 0, end: 7, step: 2 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [0, 2, 4, 6]
# start: 5, end: -1, step: 4 --> Slice [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] into Slice [5, 9]
return _Slice(self._file, self._channels, slice(start, stop, step))
ret = self._copy_slice()
ret._time = slice(start, stop, step)
return ret

def _channel_slice(self, instruction):
return _Slice(self._file, self._channels[instruction], self._time)
ret = self._copy_slice()
ret._channels = self._channels[instruction]
return ret

def _transform(self, transformation):
ret = self._copy_slice()
ret._transformations.append(transformation)
return ret

def _copy_slice(self):
copied = _Slice(self._file, self._channels, self._time)
copied._transformations = self._transformations.copy()
return copied


class BRWFile(File, _Slice):
Expand Down Expand Up @@ -270,4 +313,8 @@ def get_channel_group(self, group_id):
return ChannelGroup._from_bxr(self, data)


class TransformationError(Exception):
pass


__all__ = ["File", "BRWFile", "BXRFile"]
78 changes: 78 additions & 0 deletions bwpy/mea_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from . import signal
import abc
from scipy.stats import norm


class Viewer(abc.ABC):
@abc.abstractmethod
def build_view(self, slice, wiew_method, window_size):
pass


class MEAViewer(Viewer):
colorscale = [[0, "#ebe834"], [1.0, "#eb4034"]]

def build_view(
self, file, slice=None, view_method="amplitude", window_size=100, data=None
):
try:
import plotly.graph_objects as go
except:
raise ModuleNotFoundError(
"You have to install plotly in order to use MEAViewer."
)

fig = go.Figure()
if slice and data:
raise ValueError("slice and data arguments are mutually exclusives.")
if slice:
apply_transformation = getattr(signal, view_method)
signals = apply_transformation(slice, window_size).data
else:
signals = data

max_val = self.get_up_bound(signals, file)
for signal_frame in signals:
fig.add_trace(
go.Heatmap(
visible=False,
z=signal_frame,
zmin=0,
zmax=max_val,
colorscale=self.colorscale,
)
)
return self.format_plot(fig)

def get_up_bound(self, data, file):
up_limit = file.convert(file.max_volt) * 0.98
no_artifacts = data[data < up_limit]
mu, sd = norm.fit(no_artifacts.reshape(-1))
return mu + 2 * sd

def format_plot(self, fig):
# Create and add slider
fig.data[0].visible = True
steps = []
for i in range(len(fig.data)):
step = dict(
method="update",
args=[
{"visible": [False] * len(fig.data)},
{"title": "Time slice: " + str(i)},
], # layout attribute
)
step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible"
steps.append(step)

sliders = [
dict(active=0, currentvalue={"prefix": "Time: "}, pad={"t": 50}, steps=steps)
]

fig.update_layout(
yaxis=dict(scaleanchor="x", autorange="reversed"), sliders=sliders
)
return fig

def show(self, fig):
fig.show()
133 changes: 133 additions & 0 deletions bwpy/signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from . import mea_viewer
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import can be removed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?
I use it in the Shutter function because it has to display by default.

import re
import abc
import functools
import numpy as np
import numpy.lib.stride_tricks as np_tricks


__all__ = []


def _transformer_factory(cls):
@functools.wraps(cls)
def transformer_factory(slice, *args, **kwargs):
return slice._transform(cls(*args, **kwargs))

return transformer_factory


class Transformer(abc.ABC):
def __init_subclass__(cls, operator=None, **kwargs) -> None:
super().__init_subclass__(**kwargs)
name = operator or re.sub("([a-z0-9])([A-Z])", r"\1_\2", cls.__name__).lower()
globals()[name] = _transformer_factory(cls)
__all__.append(name)

@abc.abstractmethod
def __call__(self, data, file):
pass


class WindowedTransformer(Transformer):
def get_signal_window(self, data, window_size):
rows = data.shape[1]
cols = data.shape[2]
window_shape = (window_size, rows, cols)
# sliding window returns more complex shape like (num_windows, 1, 1, rows, cols, window_size)
# with the reshape we get rid of the unnecessary complexity (1, 1)
return np_tricks.sliding_window_view(data, window_shape).reshape(
-1, window_size, rows, cols
)


class Variation(WindowedTransformer):
def __init__(self, window_size):
self.window_size = window_size

def __call__(self, data, slice, file):
# If data have only 2 dimensions windowing is not necessary
if data.ndim == 2:
return data
else:
windows = self.get_signal_window(data, self.window_size)
return np.max(np.abs(windows), axis=1) - np.min(np.abs(windows), axis=1)


class Amplitude(WindowedTransformer):
def __init__(self, window_size):
self.window_size = window_size

def __call__(self, data, slice, file):
# If data have only 2 dimensions windowing is not necessary
if data.ndim == 2:
return data
else:
windows = self.get_signal_window(data, self.window_size)
return np.max(np.abs(windows), axis=1)


class Energy(WindowedTransformer):
def __init__(self, window_size):
self.window_size = window_size

def __call__(self, data, slice, file):
if data.ndim == 2:
return data
else:
windows = self.get_signal_window(data, self.window_size)
return np.sum(np.square(windows), axis=1)


class NoMethod(Transformer):
def __call__(self, data, slice, file):
return np.moveaxis(data, 2, 0)


class Noop(Transformer):
"""Noop that doesn't transform the data at all."""

def __call__(self, data, slice, file):
return data


class DetectArtifacts(WindowedTransformer):
def __call__(self, data, slice, file):
# If data have only 2 dimensions windowing is not necessary
if data.ndim == 2:
return data
else:
up_limit = file.convert(file.max_volt) * 0.98
out_bounds = data > up_limit
mask = np.sum(out_bounds, axis=(1, 2)) > 80
return mask


class Shutter(Transformer):
def __init__(self, data, delay_ms, callable=None):
self.delay_ms = delay_ms
self.data = data
self.callable = callable

def __call__(self, mask, slice, file):
if mask.ndim > 1:
raise ValueError("mask must be a 1dim array.")

delay = self.ms_to_idx(file, self.delay_ms)
for i in range(len(mask) - 1, 0, -1):
if mask[i]:
mask[i : i + delay] = 1
masked_data = self.data[mask]

window_size = self.ms_to_idx(file, 100)
mea_viewer.MEAViewer().build_view(
file, data=masked_data, view_method="no_method", window_size=window_size
).show()

if self.callable:
return self.callable(self.data, mask)
else:
return masked_data

def ms_to_idx(self, file, delay_ms):
return int(delay_ms * file.sampling_rate * 1000)
9 changes: 5 additions & 4 deletions tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def tearDown(self) -> None:

def test_file_integrity(self):
self.assertEqual(
self.file.data[()].shape, (self.file.n_channels, self.file.n_frames)
self.file.data[()].shape, (*self.file.layout, self.file.n_frames)
)

def test_basic_slicing(self):
Expand All @@ -33,7 +33,7 @@ def test_basic_slicing(self):
"get time slice failed",
)
self.assertEqual(
(self.file.n_channels, 1), self.file.t[1].data.shape, "get time slice failed"
(*self.file.layout, 1), self.file.t[1].data.shape, "get time slice failed"
)
self.assertEqual(
True,
Expand All @@ -51,7 +51,7 @@ def test_basic_slicing(self):

def test_concatenate_slicing(self):
self.assertEqual(
(60 * 60, 90),
(60, 60, 90),
self.file.t[:90].ch[0:60, 0:60].data.shape,
"2 slices concat failed",
)
Expand All @@ -72,12 +72,13 @@ def test_concatenate_slicing(self):

def test_slicing_out_of_range_index(self):
self.assertEqual(
(64 * 2, 100),
(64, 2, 100),
self.file.t[:150].ch[:850, -2:88880].data.shape,
"slicing out of range index failed",
)

def test_empty_slice(self):
empty_slice = bwpy.File(f"{path}/test_samples/empty.brw", "r")
empty_slice.t[:150].ch[:850, :].data
with self.assertRaises(ValueError, "It should throw a ValueError") as context:
empty_slice.t[:150].ch[:850, :].data