diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py
index 67ca2c509..44ad1b135 100644
--- a/strax/processing/peak_splitting.py
+++ b/strax/processing/peak_splitting.py
@@ -26,10 +26,11 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
     splitter = dict(local_minimum=LocalMinimumSplitter,
                     natural_breaks=NaturalBreaksSplitter)[algorithm]()
 
-    data_type_is_not_supported = data_type not in ('hitlets', 'peaks')
+    data_type_is_not_supported = data_type not in ('hitlets', 'peaks', 'merged_s2s')
     if data_type_is_not_supported:
         raise TypeError(f'Data_type "{data_type}" is not supported.')
-    return splitter(peaks, records, to_pe, data_type, **kwargs)
+    concat_orig = data_type != 'merged_s2s'
+    return splitter(peaks, records, to_pe, data_type, concat_orig, **kwargs)
 
 
 NO_MORE_SPLITS = -9999999
@@ -44,6 +45,9 @@ class PeakSplitter:
-    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
+    :param data_type: 'peaks', 'hitlets' or 'merged_s2s'. Specifies whether to use
         sum_waveform or get_hitlets_data to compute the waveform of
         the new split peaks/hitlets.
+    :param concat_orig: If True, return the original (unsplit) peaks and
+        the new peaks from splitting concatenated together. Otherwise,
+        return a tuple of (original peaks, new peaks).
     :param do_iterations: maximum number of times peaks are
         recursively split.
     :param min_area: Minimum area to do split. Smaller peaks are not split.
@@ -55,7 +59,7 @@ class PeakSplitter:
     """
     find_split_args_defaults: tuple
 
-    def __call__(self, peaks, records, to_pe, data_type,
+    def __call__(self, peaks, records, to_pe, data_type, concat_orig,
                  do_iterations=1, min_area=0, **kwargs):
         if not len(records) or not len(peaks) or not do_iterations:
             return peaks
@@ -88,7 +92,8 @@ def __call__(self, peaks, records, to_pe, data_type,
             orig_dt=records[0]['dt'],
             min_area=min_area,
             args_options=tuple(args_options),
-            result_dtype=peaks.dtype)
+            result_dtype=peaks.dtype,
+            data_type=data_type)
 
         if is_split.sum() != 0:
             # Found new peaks: compute basic properties
@@ -98,16 +103,21 @@ def __call__(self, peaks, records, to_pe, data_type,
             elif data_type == 'hitlets':
                 # Add record fields here
                 new_peaks = strax.sort_by_time(new_peaks)  # Hitlets are not necessarily sorted after splitting
-                new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
+                new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
+            elif data_type == 'merged_s2s':
+                strax.compute_widths(new_peaks)
 
             # ... and recurse (if needed)
-            new_peaks = self(new_peaks, records, to_pe, data_type,
+            new_peaks = self(new_peaks, records, to_pe, data_type, concat_orig=True,
                              do_iterations=do_iterations - 1,
                              min_area=min_area, **kwargs)
             if np.any(new_peaks['length'] == 0):
                 raise ValueError('Want to add a new zero-length peak after splitting!')
 
-            peaks = strax.sort_by_time(np.concatenate([peaks[~is_split],
-                                                       new_peaks]))
+            if concat_orig:
+                peaks = strax.sort_by_time(np.concatenate([peaks[~is_split],
+                                                           new_peaks]))
+            else:
+                peaks = peaks[~is_split], strax.sort_by_time(new_peaks)
 
         return peaks
@@ -115,13 +125,14 @@ def __call__(self, peaks, records, to_pe, data_type,
 @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
 @numba.jit(nopython=True, nogil=True)
 def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
-                 args_options,
+                 args_options, data_type,
                  _result_buffer=None, result_dtype=None):
     """Loop over peaks, pass waveforms to algorithm,
     construct new peaks if and where a split occurs.
     """
     new_peaks = _result_buffer
     offset = 0
+    reset_dt = data_type != 'merged_s2s'
 
     for p_i, p in enumerate(peaks):
         if p['area'] < min_area:
@@ -143,8 +154,14 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
            # Set the dt to the original (lowest) dt first;
            # this may change when the sum waveform of the new peak
            # is computed
-            r['dt'] = orig_dt
-            r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
+            if reset_dt:
+                r['dt'] = orig_dt
+                r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
+            else:
+                r['dt'] = p['dt']
+                r['length'] = (split_i - prev_split_i)
+                r['data'][:r['length']] = p['data'][prev_split_i: split_i]
+                r['area'] = p['data'][prev_split_i: split_i].sum()
             r['max_gap'] = -1  # Too lazy to compute this
             if r['length'] <= 0:
                 print(p['data'])