From ade2998df790438d890bcb2f7e1afda828247f97 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 3 May 2023 22:21:55 +0200 Subject: [PATCH] Start memmap refacoring. --- neo/rawio/openephysbinaryrawio.py | 39 ++++++++++++++++++++++++------- neo/rawio/spikeglxrawio.py | 34 ++++++++++++++++++++------- neo/rawio/utils.py | 34 +++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 neo/rawio/utils.py diff --git a/neo/rawio/openephysbinaryrawio.py b/neo/rawio/openephysbinaryrawio.py index 650c96672..c2b7c87ff 100644 --- a/neo/rawio/openephysbinaryrawio.py +++ b/neo/rawio/openephysbinaryrawio.py @@ -20,6 +20,7 @@ from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype) +from .utils import create_memmap_buffer, get_memmap_shape class OpenEphysBinaryRawIO(BaseRawIO): @@ -133,8 +134,8 @@ def _parse_header(self): for seg_index in range(nb_segment_per_block[block_index]): for stream_index, d in self._sig_streams[block_index][seg_index].items(): num_channels = len(d['channels']) - memmap_sigs = np.memmap(d['raw_filename'], d['dtype'], - order='C', mode='r').reshape(-1, num_channels) + #~ memmap_sigs = np.memmap(d['raw_filename'], d['dtype'], + #~ order='C', mode='r').reshape(-1, num_channels) channel_names = [ch["channel_name"] for ch in d["channels"]] # if there is a sync channel and it should not be loaded, # find the right channel index and slice the memmap @@ -145,12 +146,19 @@ def _parse_header(self): # only sync channel in last position is supported to keep memmap if sync_channel_index == num_channels - 1: - memmap_sigs = memmap_sigs[:, :-1] + #~ memmap_sigs = memmap_sigs[:, :-1] + #~ pass + d['remove_last_channel'] = True else: raise NotImplementedError("SYNC channel removal is only supported " "when the sync channel is in the last " "position") - d['memmap'] = memmap_sigs + else: + d['remove_last_channel'] = False + # d['memmap'] = memmap_sigs + shape = get_memmap_shape(d['raw_filename'], d['dtype'], num_channels=num_channels) + fid = open(d['raw_filename'], mode="rb") + d['memmap_args'] = (fid, shape, np.dtype(d['dtype']), 0) # events zone @@ -248,7 +256,9 @@ def _parse_header(self): # loop over signals for stream_index, d in self._sig_streams[block_index][seg_index].items(): t_start = d['t_start'] - dur = d['memmap'].shape[0] / float(d['sample_rate']) + #~ dur = d['memmap'].shape[0] / float(d['sample_rate']) + memmap_sigs = create_memmap_buffer(*d['memmap_args']) + dur = memmap_sigs.shape[0] / float(d['sample_rate']) t_stop = t_start + dur if global_t_start is None or global_t_start > t_start: global_t_start = t_start @@ -327,6 +337,14 @@ def _parse_header(self): arr_ann = arr_ann[selected_indices] ev_ann['__array_annotations__'][k] = arr_ann + def __del__(self): + # need an explicit close + for block_index in range(self.header['nb_block']): + for seg_index in range(self.header['nb_segment'][block_index]): + for stream_index, d in self._sig_streams[block_index][seg_index].items(): + fid, *_ = d['memmap_args'] + fid.close() + def _segment_t_start(self, block_index, seg_index): return self._t_start_segments[block_index][seg_index] @@ -343,8 +361,9 @@ def _channels_to_group_id(self, channel_indexes): return group_id def _get_signal_size(self, block_index, seg_index, stream_index): - sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap'] - return sigs.shape[0] + #~ sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap'] + memmap_sigs = create_memmap_buffer(*self._sig_streams[block_index][seg_index][stream_index]['memmap_args']) + return memmap_sigs.shape[0] def _get_signal_t_start(self, block_index, seg_index, stream_index): t_start = self._sig_streams[block_index][seg_index][stream_index]['t_start'] @@ -352,7 +371,11 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index): def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap'] + #~ sigs = self._sig_streams[block_index][seg_index][stream_index]['memmap'] + d = self._sig_streams[block_index][seg_index][stream_index] + sigs = create_memmap_buffer(*d['memmap_args']) + if d['remove_last_channel']: + sigs = sigs[:, :-1] sigs = sigs[i_start:i_stop, :] if channel_indexes is not None: sigs = sigs[:, channel_indexes] diff --git a/neo/rawio/spikeglxrawio.py b/neo/rawio/spikeglxrawio.py index 495b428b2..9922b467a 100644 --- a/neo/rawio/spikeglxrawio.py +++ b/neo/rawio/spikeglxrawio.py @@ -52,6 +52,7 @@ from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype) +from .utils import create_memmap_buffer, get_memmap_shape from pathlib import Path import os @@ -91,7 +92,8 @@ def _parse_header(self): nb_segment = np.unique([info['seg_index'] for info in self.signals_info_list]).size - self._memmaps = {} + # self._memmaps = {} + self._memmap_args = {} self.signals_info_dict = {} for info in self.signals_info_list: # key is (seg_index, stream_name) @@ -100,11 +102,14 @@ def _parse_header(self): self.signals_info_dict[key] = info # create memmap - data = np.memmap(info['bin_file'], dtype='int16', mode='r', offset=0, order='C') - # this should be (info['sample_length'], info['num_chan']) - # be some file are shorten - data = data.reshape(-1, info['num_chan']) - self._memmaps[key] = data + #~ data = np.memmap(info['bin_file'], dtype='int16', mode='r', offset=0, order='C') + #~ # this should be (info['sample_length'], info['num_chan']) + #~ # be some file are shorten + #~ data = data.reshape(-1, info['num_chan']) + #~ self._memmaps[key] = data + shape = get_memmap_shape(info['bin_file'], 'int16', num_channels= info['num_chan']) + fid = open(info['bin_file'], "rb") + self._memmap_args[key] = (fid, shape, np.dtype('int16'), 0) # create channel header signal_streams = [] @@ -182,7 +187,13 @@ def _parse_header(self): loc = np.concatenate((loc, [[0., 0.]]), axis=0) for ndim in range(loc.shape[1]): sig_ann['__array_annotations__'][f'channel_location_{ndim}'] = loc[:, ndim] - + + def __del__(self): + # need an explicit close + for k, args in self._memmap_args.items(): + fid, *_ = args + fid.close() + def _segment_t_start(self, block_index, seg_index): return 0. @@ -191,7 +202,9 @@ def _segment_t_stop(self, block_index, seg_index): def _get_signal_size(self, block_index, seg_index, stream_index): stream_id = self.header['signal_streams'][stream_index]['id'] - memmap = self._memmaps[seg_index, stream_id] + #~ memmap = self._memmaps[seg_index, stream_id] + key = (seg_index, stream_id) + memmap = create_memmap_buffer(*self._memmap_args[key]) return int(memmap.shape[0]) def _get_signal_t_start(self, block_index, seg_index, stream_index): @@ -200,7 +213,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index): def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): stream_id = self.header['signal_streams'][stream_index]['id'] - memmap = self._memmaps[seg_index, stream_id] + #~ memmap = self._memmaps[seg_index, stream_id] + key = (seg_index, stream_id) + memmap = create_memmap_buffer(*self._memmap_args[key]) + if channel_indexes is None: if self.load_sync_channel: channel_selection = slice(None) diff --git a/neo/rawio/utils.py b/neo/rawio/utils.py new file mode 100644 index 000000000..23c1c5d0c --- /dev/null +++ b/neo/rawio/utils.py @@ -0,0 +1,34 @@ +import mmap +import numpy as np + +def get_memmap_shape(filename, dtype, num_channels=None, offset=0): + dtype = np.dtype(dtype) + with open(filename, mode='rb') as f: + f.seek(0, 2) + flen = f.tell() + bytes = flen - offset + if bytes % dtype.itemsize != 0: + raise ValueError("Size of available data is not a multiple of the data-type size.") + size = bytes // dtype.itemsize + if num_channels is None: + shape = (size,) + else: + shape = (size // num_channels, num_channels) + return shape + +def create_memmap_buffer(fid, shape, dtype, offset=0): + """ + A function that mimic the np.memmap but: + * use an already opened file as input without checking the file size. + * it handles also only the ready only case + This should be faster. + """ + dtype = np.dtype(dtype) + size = np.prod(shape, dtype='int64') + bytes = dtype.itemsize * size + start = offset - offset % mmap.ALLOCATIONGRANULARITY + bytes -= start + array_offset = offset - start + mmap_buffer = mmap.mmap(fid.fileno(), bytes, access=mmap.ACCESS_READ, offset=start) + arr = np.ndarray.__new__(np.ndarray, shape, dtype=dtype, buffer=mmap_buffer, offset=array_offset, order='c') + return arr