strax/processing/peak_splitting.py (39 changes: 28 additions & 11 deletions)
@@ -26,10 +26,11 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
     splitter = dict(local_minimum=LocalMinimumSplitter,
                     natural_breaks=NaturalBreaksSplitter)[algorithm]()
 
-    data_type_is_not_supported = data_type not in ('hitlets', 'peaks')
+    data_type_is_not_supported = data_type not in ('hitlets', 'peaks', 'merged_s2s')
     if data_type_is_not_supported:
         raise TypeError(f'Data_type "{data_type}" is not supported.')
-    return splitter(peaks, records, to_pe, data_type, **kwargs)
+    concat_orig = data_type != 'merged_s2s'
+    return splitter(peaks, records, to_pe, data_type, concat_orig, **kwargs)
 
 
 NO_MORE_SPLITS = -9999999
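
For callers, the visible effect of concat_orig is the shape of the return
value. A minimal sketch of the two call patterns (hedged: `peaks`,
`merged_s2s`, `records` and `to_pe` are assumed to come from the usual strax
processing chain, and data_type is assumed to be a keyword argument of
split_peaks as the signature above suggests):

import strax

# 'peaks' / 'hitlets': unchanged behaviour, one time-sorted array in
# which split parents are replaced by their fragments.
peaks = strax.split_peaks(peaks, records, to_pe, data_type='peaks')

# 'merged_s2s': concat_orig is False, so when at least one peak is split
# the call returns (unsplit originals, new fragments) as a tuple. If
# nothing is split, the early-return paths hand back the input array
# unchanged, so guard before unpacking.
result = strax.split_peaks(merged_s2s, records, to_pe,
                           data_type='merged_s2s')
if isinstance(result, tuple):
    unsplit, fragments = result
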
@@ -44,6 +45,9 @@ class PeakSplitter:
     :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
         sum_waveform or get_hitlets_data to compute the waveform of the
         new split peaks/hitlets.
+    :param concat_orig: If True, return the original (unsplit) peaks and
+        the new peaks from splitting concatenated together. Otherwise,
+        return them as a tuple (original peaks, new peaks).
     :param do_iterations: maximum number of times peaks are recursively split.
     :param min_area: Minimum area to do split. Smaller peaks are not split.
 
@@ -55,7 +59,7 @@ class PeakSplitter:
     """
     find_split_args_defaults: tuple
 
-    def __call__(self, peaks, records, to_pe, data_type,
+    def __call__(self, peaks, records, to_pe, data_type, concat_orig,
                  do_iterations=1, min_area=0, **kwargs):
         if not len(records) or not len(peaks) or not do_iterations:
             return peaks
@@ -88,7 +92,8 @@ def __call__(self, peaks, records, to_pe, data_type,
             orig_dt=records[0]['dt'],
             min_area=min_area,
             args_options=tuple(args_options),
-            result_dtype=peaks.dtype)
+            result_dtype=peaks.dtype,
+            data_type=data_type)
 
         if is_split.sum() != 0:
             # Found new peaks: compute basic properties
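
Passing data_type, a Python string, straight into the numba-compiled
_split_peaks relies on nopython mode supporting unicode string comparison,
which modern numba releases do. A small self-contained sanity check of that
assumption (the function name is invented for illustration):

import numba

@numba.jit(nopython=True, nogil=True)
def needs_dt_reset(data_type):
    # Mirrors the reset_dt computation inside _split_peaks.
    return data_type != 'merged_s2s'

assert needs_dt_reset('peaks')
assert not needs_dt_reset('merged_s2s')
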
@@ -98,30 +103,36 @@ def __call__(self, peaks, records, to_pe, data_type,
             elif data_type == 'hitlets':
                 # Add record fields here
                 new_peaks = strax.sort_by_time(new_peaks)  # Hitlets are not necessarily sorted after splitting
-                new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
+                new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
+            elif data_type == 'merged_s2s':
+                strax.compute_widths(new_peaks)
             # ... and recurse (if needed)
-            new_peaks = self(new_peaks, records, to_pe, data_type,
+            new_peaks = self(new_peaks, records, to_pe, data_type, concat_orig=True,
                              do_iterations=do_iterations - 1,
                              min_area=min_area, **kwargs)
             if np.any(new_peaks['length'] == 0):
                 raise ValueError('Want to add a new zero-length peak after splitting!')
 
-            peaks = strax.sort_by_time(np.concatenate([peaks[~is_split],
-                                                        new_peaks]))
+            if concat_orig:
+                peaks = strax.sort_by_time(np.concatenate([peaks[~is_split],
+                                                           new_peaks]))
+            else:
+                peaks = peaks[~is_split], strax.sort_by_time(new_peaks)
 
         return peaks
 
     @staticmethod
     @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
     @numba.jit(nopython=True, nogil=True)
     def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
-                     args_options,
+                     args_options, data_type,
                      _result_buffer=None, result_dtype=None):
         """Loop over peaks, pass waveforms to algorithm, construct
         new peaks if and where a split occurs.
         """
         new_peaks = _result_buffer
         offset = 0
+        reset_dt = data_type != 'merged_s2s'
 
         for p_i, p in enumerate(peaks):
             if p['area'] < min_area:
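
To make the two return shapes concrete, here is a toy numpy reconstruction
of the concat_orig branch above (the dtype, arrays and mask are invented for
illustration, and strax.sort_by_time is approximated by np.sort over the
'time' field):

import numpy as np

# Three parent peaks; the last two were split into three fragments.
toy_dtype = np.dtype([('time', np.int64), ('length', np.int32)])
peaks = np.array([(0, 10), (100, 10), (200, 10)], dtype=toy_dtype)
new_peaks = np.array([(100, 4), (105, 5), (200, 9)], dtype=toy_dtype)
is_split = np.array([False, True, True])

concat_orig = False  # as set for data_type == 'merged_s2s'
if concat_orig:
    # 'peaks' / 'hitlets': one merged, time-ordered array.
    out = np.sort(np.concatenate([peaks[~is_split], new_peaks]),
                  order='time')
else:
    # 'merged_s2s': unsplit originals and fragments kept apart.
    out = peaks[~is_split], np.sort(new_peaks, order='time')
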
@@ -143,8 +154,14 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
                 # Set the dt to the original (lowest) dt first;
                 # this may change when the sum waveform of the new peak
                 # is computed
-                r['dt'] = orig_dt
-                r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
+                if reset_dt:
+                    r['dt'] = orig_dt
+                    r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
+                else:
+                    r['dt'] = p['dt']
+                    r['length'] = split_i - prev_split_i
+                r['data'][:r['length']] = p['data'][prev_split_i:split_i]
+                r['area'] = p['data'][prev_split_i:split_i].sum()
                 r['max_gap'] = -1  # Too lazy to compute this
                 if r['length'] <= 0:
                     print(p['data'])
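
The dt bookkeeping in the two branches can be checked with toy numbers
(invented for illustration): both paths describe the same stretch of
waveform, they just sample it differently.

# A parent peak stored at dt = 100 ns, built on records with
# orig_dt = 10 ns, split between sample indices 20 and 50.
p_dt, orig_dt = 100, 10
prev_split_i, split_i = 20, 50

# reset_dt branch ('peaks' / 'hitlets'): re-express the fragment at the
# finest record dt so sum_waveform can rebuild it later.
dt_default = orig_dt                                         # 10 ns
length_default = (split_i - prev_split_i) * p_dt // orig_dt  # 300 samples

# merged_s2s branch: copy the parent's samples as-is.
dt_merged = p_dt                        # 100 ns
length_merged = split_i - prev_split_i  # 30 samples

# Both cover the same 3000 ns of waveform:
assert dt_default * length_default == dt_merged * length_merged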