From 51eb223027fc2ba1f4cd582f45d5f867e617736a Mon Sep 17 00:00:00 2001 From: Tianyu Zhu Date: Tue, 24 Aug 2021 12:36:31 -0500 Subject: [PATCH 1/5] changes for splitting merged s2s --- strax/processing/peak_merging.py | 2 +- strax/processing/peak_splitting.py | 37 ++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/strax/processing/peak_merging.py b/strax/processing/peak_merging.py index 4b46971eb..15d72f849 100644 --- a/strax/processing/peak_merging.py +++ b/strax/processing/peak_merging.py @@ -92,7 +92,7 @@ def replace_merged(orig, merge): return orig skip_windows = strax.touching_windows(orig, merge) - skip_n = np.diff(skip_windows, axis=1).sum() + skip_n = len(np.unique(np.concatenate([np.arange(l, r) for l, r in skip_windows]))) result = np.zeros(len(orig) - skip_n + len(merge), dtype=orig.dtype) _replace_merged(result, orig, merge, skip_windows) diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py index 67ca2c509..8330f30c9 100644 --- a/strax/processing/peak_splitting.py +++ b/strax/processing/peak_splitting.py @@ -26,10 +26,11 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum', splitter = dict(local_minimum=LocalMinimumSplitter, natural_breaks=NaturalBreaksSplitter)[algorithm]() - data_type_is_not_supported = data_type not in ('hitlets', 'peaks') + data_type_is_not_supported = data_type not in ('hitlets', 'peaks', 'merged_s2s') if data_type_is_not_supported: raise TypeError(f'Data_type "{data_type}" is not supported.') - return splitter(peaks, records, to_pe, data_type, **kwargs) + concat_orig = data_type != 'merged_s2s' + return splitter(peaks, records, to_pe, data_type, concat_orig, **kwargs) NO_MORE_SPLITS = -9999999 @@ -55,9 +56,9 @@ class PeakSplitter: """ find_split_args_defaults: tuple - def __call__(self, peaks, records, to_pe, data_type, + def __call__(self, peaks, records, to_pe, data_type, concat_orig, do_iterations=1, min_area=0, **kwargs): - if not len(records) or not len(peaks) or not do_iterations: + if not len(peaks) or not do_iterations: return peaks # Build the *args tuple for self.find_split_points from kwargs @@ -85,7 +86,6 @@ def __call__(self, peaks, records, to_pe, data_type, split_finder=self.find_split_points, peaks=peaks, is_split=is_split, - orig_dt=records[0]['dt'], min_area=min_area, args_options=tuple(args_options), result_dtype=peaks.dtype) @@ -93,28 +93,39 @@ def __call__(self, peaks, records, to_pe, data_type, if is_split.sum() != 0: # Found new peaks: compute basic properties if data_type == 'peaks': + orig_dt = records[0]['dt'] + new_peaks['length'] = new_peaks['length'] * new_peaks['dt'] / orig_dt + new_peaks['dt'] = orig_dt strax.sum_waveform(new_peaks, records, to_pe) strax.compute_widths(new_peaks) elif data_type == 'hitlets': # Add record fields here + orig_dt = records[0]['dt'] + new_peaks['length'] = new_peaks['length'] * new_peaks['dt'] / orig_dt + new_peaks['dt'] = orig_dt new_peaks = strax.sort_by_time(new_peaks) # Hitlets are not necessarily sorted after splitting - new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe) + new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe) + elif data_type == 'merged_s2s': + strax.compute_widths(new_peaks) # ... and recurse (if needed) - new_peaks = self(new_peaks, records, to_pe, data_type, + new_peaks = self(new_peaks, records, to_pe, data_type, concat_orig=True, do_iterations=do_iterations - 1, min_area=min_area, **kwargs) if np.any(new_peaks['length'] == 0): raise ValueError('Want to add a new zero-length peak after splitting!') - peaks = strax.sort_by_time(np.concatenate([peaks[~is_split], - new_peaks])) + if concat_orig: + peaks = strax.sort_by_time(np.concatenate([peaks[~is_split], + new_peaks])) + else: + peaks = peaks[~is_split], strax.sort_by_time(new_peaks) return peaks @staticmethod @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4)) @numba.jit(nopython=True, nogil=True) - def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, + def _split_peaks(split_finder, peaks, is_split, min_area, args_options, _result_buffer=None, result_dtype=None): """Loop over peaks, pass waveforms to algorithm, construct @@ -143,8 +154,10 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, # Set the dt to the original (lowest) dt first; # this may change when the sum waveform of the new peak # is computed - r['dt'] = orig_dt - r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt + r['dt'] = p['dt'] + r['length'] = (split_i - prev_split_i) + r['data'][:r['length']] = p['data'][prev_split_i: split_i] + r['area'] = p['data'][prev_split_i: split_i].sum() r['max_gap'] = -1 # Too lazy to compute this if r['length'] <= 0: print(p['data']) From 5fd83037de35fbc1520e53d12802a71793d2c6b3 Mon Sep 17 00:00:00 2001 From: Tianyu Zhu Date: Tue, 24 Aug 2021 12:46:59 -0500 Subject: [PATCH 2/5] revert testing features --- strax/processing/peak_merging.py | 2 +- strax/processing/peak_splitting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/strax/processing/peak_merging.py b/strax/processing/peak_merging.py index 15d72f849..4b46971eb 100644 --- a/strax/processing/peak_merging.py +++ b/strax/processing/peak_merging.py @@ -92,7 +92,7 @@ def replace_merged(orig, merge): return orig skip_windows = strax.touching_windows(orig, merge) - skip_n = len(np.unique(np.concatenate([np.arange(l, r) for l, r in skip_windows]))) + skip_n = np.diff(skip_windows, axis=1).sum() result = np.zeros(len(orig) - skip_n + len(merge), dtype=orig.dtype) _replace_merged(result, orig, merge, skip_windows) diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py index 8330f30c9..9c1e2b7db 100644 --- a/strax/processing/peak_splitting.py +++ b/strax/processing/peak_splitting.py @@ -58,7 +58,7 @@ class PeakSplitter: def __call__(self, peaks, records, to_pe, data_type, concat_orig, do_iterations=1, min_area=0, **kwargs): - if not len(peaks) or not do_iterations: + if not len(records) or not len(peaks) or not do_iterations: return peaks # Build the *args tuple for self.find_split_points from kwargs From ac38efc50010e52a355b8aa2e37236524e2892f7 Mon Sep 17 00:00:00 2001 From: Tianyu Zhu Date: Tue, 24 Aug 2021 12:55:59 -0500 Subject: [PATCH 3/5] update docstring --- strax/processing/peak_splitting.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py index 9c1e2b7db..fa11a16db 100644 --- a/strax/processing/peak_splitting.py +++ b/strax/processing/peak_splitting.py @@ -45,6 +45,9 @@ class PeakSplitter: :param data_type: 'peaks' or 'hitlets'. Specifies whether to use sum_waveform or get_hitlets_data to compute the waveform of the new split peaks/hitlets. + :param concat_orig: Return original peaks and new peaks from + splitting concatenated together. Otherwise, return as a tuple + (original peaks, new peaks) :param do_iterations: maximum number of times peaks are recursively split. :param min_area: Minimum area to do split. Smaller peaks are not split. From b69a579b004f6195ee1d817dbdfa51db43a4e939 Mon Sep 17 00:00:00 2001 From: Tianyu Zhu Date: Tue, 24 Aug 2021 16:14:03 -0500 Subject: [PATCH 4/5] set hitlet data to 0 --- strax/processing/peak_splitting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py index fa11a16db..52bf9a556 100644 --- a/strax/processing/peak_splitting.py +++ b/strax/processing/peak_splitting.py @@ -106,6 +106,7 @@ def __call__(self, peaks, records, to_pe, data_type, concat_orig, orig_dt = records[0]['dt'] new_peaks['length'] = new_peaks['length'] * new_peaks['dt'] / orig_dt new_peaks['dt'] = orig_dt + new_peaks['data'][:] = 0 new_peaks = strax.sort_by_time(new_peaks) # Hitlets are not necessarily sorted after splitting new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe) elif data_type == 'merged_s2s': From 6979ffcd90123582d6bc43330da503930ffdab8c Mon Sep 17 00:00:00 2001 From: Tianyu Zhu Date: Tue, 24 Aug 2021 16:58:32 -0500 Subject: [PATCH 5/5] reduce code complexity --- strax/processing/peak_splitting.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py index 52bf9a556..44ad1b135 100644 --- a/strax/processing/peak_splitting.py +++ b/strax/processing/peak_splitting.py @@ -89,24 +89,19 @@ def __call__(self, peaks, records, to_pe, data_type, concat_orig, split_finder=self.find_split_points, peaks=peaks, is_split=is_split, + orig_dt=records[0]['dt'], min_area=min_area, args_options=tuple(args_options), - result_dtype=peaks.dtype) + result_dtype=peaks.dtype, + data_type=data_type,) if is_split.sum() != 0: # Found new peaks: compute basic properties if data_type == 'peaks': - orig_dt = records[0]['dt'] - new_peaks['length'] = new_peaks['length'] * new_peaks['dt'] / orig_dt - new_peaks['dt'] = orig_dt strax.sum_waveform(new_peaks, records, to_pe) strax.compute_widths(new_peaks) elif data_type == 'hitlets': # Add record fields here - orig_dt = records[0]['dt'] - new_peaks['length'] = new_peaks['length'] * new_peaks['dt'] / orig_dt - new_peaks['dt'] = orig_dt - new_peaks['data'][:] = 0 new_peaks = strax.sort_by_time(new_peaks) # Hitlets are not necessarily sorted after splitting new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe) elif data_type == 'merged_s2s': @@ -129,14 +124,15 @@ def __call__(self, peaks, records, to_pe, data_type, concat_orig, @staticmethod @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4)) @numba.jit(nopython=True, nogil=True) - def _split_peaks(split_finder, peaks, is_split, min_area, - args_options, + def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, + args_options, data_type, _result_buffer=None, result_dtype=None): """Loop over peaks, pass waveforms to algorithm, construct new peaks if and where a split occurs. """ new_peaks = _result_buffer offset = 0 + reset_dt = data_type != 'merged_s2s' for p_i, p in enumerate(peaks): if p['area'] < min_area: @@ -158,10 +154,14 @@ def _split_peaks(split_finder, peaks, is_split, min_area, # Set the dt to the original (lowest) dt first; # this may change when the sum waveform of the new peak # is computed - r['dt'] = p['dt'] - r['length'] = (split_i - prev_split_i) - r['data'][:r['length']] = p['data'][prev_split_i: split_i] - r['area'] = p['data'][prev_split_i: split_i].sum() + if reset_dt: + r['dt'] = orig_dt + r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt + else: + r['dt'] = p['dt'] + r['length'] = (split_i - prev_split_i) + r['data'][:r['length']] = p['data'][prev_split_i: split_i] + r['area'] = p['data'][prev_split_i: split_i].sum() r['max_gap'] = -1 # Too lazy to compute this if r['length'] <= 0: print(p['data'])