From 6aca8a216614f8342540cbd20dce89a9dd3829c0 Mon Sep 17 00:00:00 2001
From: Ujjwal Panda
Date: Thu, 10 Apr 2025 00:48:12 +0530
Subject: [PATCH 1/3] Move to `src` layout + format with `black`.

---
 .coveragerc | 5 +-
 Makefile | 2 +-
 docs/source/conf.py | 22 +-
 pyproject.toml | 4 +-
 riptide/pipeline/config_validation.py | 198 ------------
 riptide/tests/__init__.py | 1 -
 riptide/tests/presto_generation.py | 58 ----
 riptide/tests/run_tests.py | 11 -
 riptide/tests/test_pipeline.py | 169 ------------
 riptide/tests/test_rseek.py | 68 -----
 setup.py | 18 +-
 {riptide => src/riptide}/__init__.py | 44 ++-
 src/riptide/_version.py | 21 ++
 {riptide => src/riptide}/apps/rseek.py | 132 +++++----
 {riptide => src/riptide}/candidate.py | 115 ++++----
 {riptide => src/riptide}/clustering.py | 13 +-
 {riptide => src/riptide}/cpp/README.md | 0
 {riptide => src/riptide}/cpp/block.hpp | 0
 {riptide => src/riptide}/cpp/downsample.hpp | 0
 {riptide => src/riptide}/cpp/kernels.hpp | 0
 {riptide => src/riptide}/cpp/periodogram.hpp | 0
 .../riptide}/cpp/python_bindings.cpp | 0
 .../riptide}/cpp/running_median.hpp | 0
 {riptide => src/riptide}/cpp/snr.hpp | 0
 {riptide => src/riptide}/cpp/transforms.hpp | 0
 {riptide => src/riptide}/ffautils.py | 1 +
 {riptide => src/riptide}/folding.py | 16 +-
 {riptide => src/riptide}/libffa.py | 34 +--
 {riptide => src/riptide}/metadata.py | 64 +++--
 {riptide => src/riptide}/peak_detection.py | 82 +++---
 {riptide => src/riptide}/periodogram.py | 47 ++--
 {riptide => src/riptide}/pipeline/__init__.py | 0
 .../riptide}/pipeline/config/example.yaml | 0
 src/riptide/pipeline/config_validation.py | 259 ++++++++++++++++++
 {riptide => src/riptide}/pipeline/dmiter.py | 120 +++++---
 .../riptide}/pipeline/harmonic_testing.py | 63 +++--
 .../riptide}/pipeline/peak_cluster.py | 44 +--
 {riptide => src/riptide}/pipeline/pipeline.py | 227 ++++++++------
 .../riptide}/pipeline/worker_pool.py | 31 +--
 {riptide => src/riptide}/reading/__init__.py | 0
 {riptide => src/riptide}/reading/presto.py | 95 +++----
 {riptide => src/riptide}/reading/sigproc.py | 71 ++---
 {riptide => src/riptide}/running_medians.py | 6 +-
 {riptide => src/riptide}/search.py | 35 ++-
 {riptide => src/riptide}/serialization.py | 79 +++---
 {riptide => src/riptide}/time_series.py | 124 +++++---
 {riptide => src/riptide}/timing.py | 5 +-
 {riptide/tests => tests}/data/README.md | 0
 .../data/fake_presto_radio.dat | Bin
 .../data/fake_presto_radio.inf | 0
 .../data/fake_presto_radio_breaks.dat | Bin
 .../data/fake_presto_radio_breaks.inf | 0
 .../tests => tests}/data/fake_presto_xray.dat | Bin
 .../tests => tests}/data/fake_presto_xray.inf | 0
 .../data/fake_sigproc_float32.tim | Bin
 .../data/fake_sigproc_int8.tim | Bin
 .../data/fake_sigproc_uint8.tim | Bin
 .../data/fake_sigproc_uint8_nosignedkey.tim | Bin
 .../tests => tests}/pipeline_config_A.yml | 0
 .../tests => tests}/pipeline_config_B.yml | 0
 tests/presto_generation.py | 1 +
 .../test_ffa_base_functions.py | 86 +++---
 .../tests => tests}/test_ffa_search_pgram.py | 47 ++--
 tests/test_pipeline.py | 246 +++++++++++++++++
 tests/test_rseek.py | 145 ++++++++++
 .../tests => tests}/test_running_median.py | 22 +-
 {riptide/tests => tests}/test_snr.py | 18 +-
 {riptide/tests => tests}/test_time_series.py | 66 +++--
 68 files changed, 1629 insertions(+), 1286 deletions(-)
 delete mode 100644 riptide/pipeline/config_validation.py
 delete mode 100644 riptide/tests/__init__.py
 delete mode 100644 riptide/tests/presto_generation.py
 delete mode 100644 riptide/tests/run_tests.py
 delete mode 100644 riptide/tests/test_pipeline.py
 delete mode 100644 riptide/tests/test_rseek.py
 rename {riptide => src/riptide}/__init__.py (55%)
 create mode 100644 src/riptide/_version.py
 rename {riptide => src/riptide}/apps/rseek.py (52%)
 rename {riptide => src/riptide}/candidate.py (68%)
 rename {riptide => src/riptide}/clustering.py (84%)
 rename {riptide => src/riptide}/cpp/README.md (100%)
 rename {riptide => src/riptide}/cpp/block.hpp (100%)
 rename {riptide => src/riptide}/cpp/downsample.hpp (100%)
 rename {riptide => src/riptide}/cpp/kernels.hpp (100%)
 rename {riptide => src/riptide}/cpp/periodogram.hpp (100%)
 rename {riptide => src/riptide}/cpp/python_bindings.cpp (100%)
 rename {riptide => src/riptide}/cpp/running_median.hpp (100%)
 rename {riptide => src/riptide}/cpp/snr.hpp (100%)
 rename {riptide => src/riptide}/cpp/transforms.hpp (100%)
 rename {riptide => src/riptide}/ffautils.py (99%)
 rename {riptide => src/riptide}/folding.py (88%)
 rename {riptide => src/riptide}/libffa.py (91%)
 rename {riptide => src/riptide}/metadata.py (64%)
 rename {riptide => src/riptide}/peak_detection.py (80%)
 rename {riptide => src/riptide}/periodogram.py (67%)
 rename {riptide => src/riptide}/pipeline/__init__.py (100%)
 rename {riptide => src/riptide}/pipeline/config/example.yaml (100%)
 create mode 100644 src/riptide/pipeline/config_validation.py
 rename {riptide => src/riptide}/pipeline/dmiter.py (76%)
 rename {riptide => src/riptide}/pipeline/harmonic_testing.py (85%)
 rename {riptide => src/riptide}/pipeline/peak_cluster.py (78%)
 rename {riptide => src/riptide}/pipeline/pipeline.py (72%)
 rename {riptide => src/riptide}/pipeline/worker_pool.py (73%)
 rename {riptide => src/riptide}/reading/__init__.py (100%)
 rename {riptide => src/riptide}/reading/presto.py (58%)
 rename {riptide => src/riptide}/reading/sigproc.py (71%)
 rename {riptide => src/riptide}/running_medians.py (97%)
 rename {riptide => src/riptide}/search.py (86%)
 rename {riptide => src/riptide}/serialization.py (66%)
 rename {riptide => src/riptide}/time_series.py (80%)
 rename {riptide => src/riptide}/timing.py (83%)
 rename {riptide/tests => tests}/data/README.md (100%)
 rename {riptide/tests => tests}/data/fake_presto_radio.dat (100%)
 rename {riptide/tests => tests}/data/fake_presto_radio.inf (100%)
 rename {riptide/tests => tests}/data/fake_presto_radio_breaks.dat (100%)
 rename {riptide/tests => tests}/data/fake_presto_radio_breaks.inf (100%)
 rename {riptide/tests => tests}/data/fake_presto_xray.dat (100%)
 rename {riptide/tests => tests}/data/fake_presto_xray.inf (100%)
 rename {riptide/tests => tests}/data/fake_sigproc_float32.tim (100%)
 rename {riptide/tests => tests}/data/fake_sigproc_int8.tim (100%)
 rename {riptide/tests => tests}/data/fake_sigproc_uint8.tim (100%)
 rename {riptide/tests => tests}/data/fake_sigproc_uint8_nosignedkey.tim (100%)
 rename {riptide/tests => tests}/pipeline_config_A.yml (100%)
 rename {riptide/tests => tests}/pipeline_config_B.yml (100%)
 create mode 100644 tests/presto_generation.py
 rename {riptide/tests => tests}/test_ffa_base_functions.py (58%)
 rename {riptide/tests => tests}/test_ffa_search_pgram.py (70%)
 create mode 100644 tests/test_pipeline.py
 create mode 100644 tests/test_rseek.py
 rename {riptide/tests => tests}/test_running_median.py (78%)
 rename {riptide/tests => tests}/test_snr.py (81%)
 rename {riptide/tests => tests}/test_time_series.py (80%)

diff --git a/.coveragerc b/.coveragerc
index d8b669b..b0a944a 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -2,11 +2,8 @@
 branch = True
 parallel = True
 concurrency = multiprocessing
-include =
-    riptide/*
 omit =
-    riptide/tests/*
-    riptide/_version.py
+    src/riptide/_version.py
 
 [report]
 exclude_lines =

diff --git a/Makefile b/Makefile
index dc6729c..af97a72 100644
--- a/Makefile
+++ b/Makefile
@@ -32,6 +32,6 @@ clean: ## Remove all python cache and build files
 	rm -f .coverage
 
 tests: ## Run the unit tests and print a coverage report
-	pytest --cov --verbose --cov-report term-missing riptide/tests
+	pytest --cov --verbose --cov-report term-missing tests
 
 .PHONY: dist install uninstall help clean tests

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1c5923e..f88f8e0 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,9 +16,9 @@
 
 # -- Project information -----------------------------------------------------
 
-project = 'riptide-ffa'
-copyright = '2021, Vincent Morello'
-author = 'Vincent Morello'
+project = "riptide-ffa"
+copyright = "2021, Vincent Morello"
+author = "Vincent Morello"
 
 
 # -- General configuration ---------------------------------------------------
@@ -29,14 +29,14 @@
 # NOTE: sphinx.ext.autosectionlabel makes it easy to reference other sections in the docs
 # See: https://stackoverflow.com/a/54843636
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.napoleon',
-    'sphinx_rtd_theme',
-    'sphinx.ext.autosectionlabel',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.napoleon",
+    "sphinx_rtd_theme",
+    "sphinx.ext.autosectionlabel",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -49,11 +49,11 @@
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
-pygments_style = 'sphinx'
+pygments_style = "sphinx"

diff --git a/pyproject.toml b/pyproject.toml
index 928b4df..2d93d21 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ dynamic = ["version"]
 license = {file = "LICENSE"}
 maintainers = [{name = "Vincent Morello", email = "vmorello@gmail.com"}]
 readme = "README.md"
-requires-python = ">=3.6"
+requires-python = ">=3.9"
 
 classifiers = [
     "Programming Language :: Python :: 3 :: Only",
@@ -57,4 +57,4 @@ homepage = "https://github.com/vmorello/riptide"
 documentation = "https://riptide-ffa.readthedocs.io"
 
 [tool.setuptools_scm]
-write_to = "riptide/_version.py"
+write_to = "src/riptide/_version.py"
null/blank"), - }, - - 'dereddening': { - 'rmed_width': Schema(strictly_positive, error="rmed_width must be a number > 0"), - 'rmed_minpts': Schema(strictly_positive, error="rmed_minpts must be a number > 0") - }, - - 'ranges': [SEARCH_RANGE_SCHEMA], - - 'clustering': { - 'radius': Schema(strictly_positive, error="clustering radius must be a number > 0"), - }, - - 'harmonic_flagging': { - 'denom_max': And(int, strictly_positive, error="denom_max must be an int > 0"), - 'phase_distance_max': And( - Use(float), strictly_positive, error="phase_distance_max must be a number > 0"), - 'dm_distance_max': And( - Use(float), strictly_positive, error="dm_distance_max must be a number > 0"), - 'snr_distance_max': And( - Use(float), strictly_positive, error="snr_distance_max must be a number > 0"), - }, - - 'candidate_filters': { - 'dm_min': Or(Use(float), None, error='Candidate dm_min must be a float or null/blank'), - 'snr_min': Or(Use(float), None, error='Candidate snr_min must be a float or null/blank'), - 'remove_harmonics': Or( - bool, None, error='remove_harmonics must be a boolean or null/blank'), - 'max_number': Or( - And(int, strictly_positive), None, - error='Candidate max_number must be an int > 0 or null/blank'), - }, - - 'plot_candidates': Schema(bool, error='plot_candidates must be a boolean'), -}) - - -def validate_range(rg, tsamp_max): - """ """ - # NOTE: In general, we leave the pipeline code to raise the exceptions, - # except if it takes too long for it to detect them; for example, if the number of candidate - # bins is too large, we don't want to wait until the candidate building stage to realize this. - period_min = rg['ffa_search']['period_min'] - period_max = rg['ffa_search']['period_max'] - bins_min = rg['ffa_search']['bins_min'] - cand_bins = rg['candidates']['bins'] - - if bins_min * tsamp_max > period_min: - raise InvalidSearchRange( - f"Search range {period_min:.3e} to {period_max:.3e} seconds: requested phase " - "resolution is too high w.r.t. coarsest input time series " - f"(tsamp = {tsamp_max:.3e} seconds). Use smaller bins_min or larger period_min.") - - if cand_bins * tsamp_max > period_min: - raise InvalidSearchRange( - f"Search range {period_min:.3e} to {period_max:.3e} seconds: " - f"cannot fold candidates with such high resolution ({cand_bins:d} bins). " - f"The coarsest input time series ({tsamp_max:.3e} seconds) does not allow it") - - -def validate_ranges_contiguity(ranges): - """ """ - for a, b in zip(ranges[:-1], ranges[1:]): - period_max_a = a['ffa_search']['period_max'] - period_min_b = b['ffa_search']['period_min'] - if not period_max_a == period_min_b: - raise InvalidSearchRange( - "Search ranges are not either non-contiguous, or not ordered by increasing trial " - f"period (period_max ({period_max_a:.6e}) != next period_min ({period_min_b:.6e})") - - -def validate_ranges(ranges, tsamp_max): - """ - Check that the search ranges are valid. Raise an exception if not. - - Parameters - ---------- - ranges : list of dict - Search ranges read from the pipeline configuration file - tsamp_max : float - Maximum sampling interval of the TimeSeries to process - - Raises - ------ - InvalidSearchRange - """ - for rg in ranges: - validate_range(rg, tsamp_max) - validate_ranges_contiguity(ranges) - - -def validate_pipeline_config(conf): - """ - Validate pipeline configuration dictionary and raise an error if it is - incorrect. This function only checks the format of the config and - the data types. 

diff --git a/riptide/tests/__init__.py b/riptide/tests/__init__.py
deleted file mode 100644
index e3f3aa9..0000000
--- a/riptide/tests/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .run_tests import test
\ No newline at end of file
diff --git a/riptide/tests/presto_generation.py b/riptide/tests/presto_generation.py
deleted file mode 100644
index 2c0f9e5..0000000
--- a/riptide/tests/presto_generation.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import os
-import numpy as np
-from riptide import TimeSeries
-
-
-INF_TEMPLATE = """
- Data file name without suffix = {basename:s}
- Telescope used = Parkes
- Instrument used = Multibeam
- Object being observed = Pulsar
- J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000
- J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000
- Data observed by = Kenji Oba
- Epoch of observation (MJD) = 59000.000000
- Barycentered? (1=yes, 0=no) = 1
- Number of bins in the time series = {nsamp:d}
- Width of each time series bin (sec) = {tsamp:.12e}
- Any breaks in the data? (1=yes, 0=no) = 0
- Type of observation (EM band) = Radio
- Beam diameter (arcsec) = 981
- Dispersion measure (cm-3 pc) = {dm:.12f}
- Central freq of low channel (Mhz) = 1182.1953125
- Total bandwidth (Mhz) = 400
- Number of channels = 1024
- Channel bandwidth (Mhz) = 0.390625
- Data analyzed by = Space Sheriff Gavan
- Any additional notes:
-    Input filterbank samples have 2 bits.
-"""
-
-def generate_data_presto(outdir, basename, tobs=128.0, tsamp=256e-6, period=1.0, dm=0.0, amplitude=20.0, ducy=0.05):
-    """
-    Generate some time series data with a fake signal, and save it in PRESTO
-    inf/dat format in the specified output directory.
-
-    Parameters
-    ----------
-    outdir : str
-        Path to the output directory
-    basename : str
-        Base file name (not path) under which the .inf and .dat files
-        will be saved.
-    **kwargs: self-explanatory
-    """
-    ### IMPORTANT: seed the RNG to get reproducible results ###
-    np.random.seed(0)
-
-    ts = TimeSeries.generate(
-        tobs, tsamp, period,
-        amplitude=amplitude, ducy=ducy, stdnoise=1.0
-    )
-    inf_text = INF_TEMPLATE.format(basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm)
-
-    inf_path = os.path.join(outdir, f"{basename}.inf")
-    dat_path = os.path.join(outdir, f"{basename}.dat")
-    with open(inf_path, 'w') as fobj:
-        fobj.write(inf_text)
-    ts.data.tofile(dat_path)
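
For context, a sketch of how this helper is consumed elsewhere in the test suite (its new home is tests/presto_generation.py, per the diffstat above); the file names here are hypothetical:

```python
# Hypothetical round trip: write fake PRESTO data, read it back with riptide
import os
import tempfile

from riptide import TimeSeries
from presto_generation import generate_data_presto  # the tests/ helper above

with tempfile.TemporaryDirectory() as outdir:
    generate_data_presto(outdir, "fake", tobs=128.0, tsamp=256e-6, period=1.0)
    ts = TimeSeries.from_presto_inf(os.path.join(outdir, "fake.inf"))
    print(ts.nsamp, ts.tsamp)
```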
- """ - tests_dir = os.path.dirname(__file__) - args = ['--verbose', tests_dir] - pytest.main(args) \ No newline at end of file diff --git a/riptide/tests/test_pipeline.py b/riptide/tests/test_pipeline.py deleted file mode 100644 index 0493df1..0000000 --- a/riptide/tests/test_pipeline.py +++ /dev/null @@ -1,169 +0,0 @@ -import os -import glob -import tempfile -from copy import deepcopy - -from pytest import raises -import yaml -import numpy as np -from riptide import TimeSeries, load_json -from riptide.pipeline.pipeline import get_parser, run_program -from riptide.pipeline.config_validation import InvalidPipelineConfig, InvalidSearchRange -from .presto_generation import generate_data_presto - -# NOTE 1: -# pipeline uses multiprocessing, to get proper coverage stats we need: -# * A .coveragerc file with the following options: -# [run] -# branch = True -# parallel = True -# concurrency = multiprocessing -# * Ensure that all instances of multiprocessing.Pool() have been closed and joined, as follows: -# >> pool.close() -# >> pool.join() - -# NOTE 2: -# To print logging output in full, call pytest like this: -# py.test --capture=no -o log_cli=True - -# NOTE 3: -# To get coverage stats, run this in the base riptide directory: -# coverage run -m pytest && coverage combine && coverage report -m --omit riptide/_version.py - - -SIGNAL_PERIOD = 1.0 -DATA_TOBS = 128.0 -DATA_TSAMP = 256e-6 - - -def runner_presto_fakepsr(fname_conf, outdir): - # Write test data - # NOTE: generate a signal bright enough to get harmonics and thus make sure - # that the harmonic filter gets to run - params = [ - # (dm, amplitude, ducy) - (0.0 , 10.0, 0.05), - (10.0, 20.0, 0.02), - (20.0, 10.0, 0.05) - ] - - for dm, amplitude, ducy in params: - basename = f"fake_DM{dm:.3f}" - generate_data_presto( - outdir, basename, tobs=DATA_TOBS, tsamp=DATA_TSAMP, period=SIGNAL_PERIOD, - dm=dm, amplitude=amplitude, ducy=ducy - ) - - ### Run pipeline ### - files = glob.glob(f'{outdir}/*.inf') - cmdline_args = ['--config', fname_conf, '--outdir', outdir] + files - parser = get_parser() - args = parser.parse_args(cmdline_args) - run_program(args) - - ### Check output sanity ### - topcand_fname = f"{outdir}/candidate_0000.json" - assert os.path.isfile(topcand_fname) - - topcand = load_json(topcand_fname) - - # NOTE: these checks depend on the RNG seed and the pipeline config - assert abs(topcand.params['period'] - SIGNAL_PERIOD) < 1.00e-4 - assert topcand.params['dm'] == 10.0 - assert topcand.params['width'] == 13 - assert abs(topcand.params['snr'] - 18.5) < 0.15 - - -def runner_presto_purenoise(fname_conf, outdir): - """ - Check that pipeline runs well even if no candidates are found - """ - dm = 0.0 - basename = f"purenoise_DM{dm:.3f}" - generate_data_presto( - outdir, basename, tobs=DATA_TOBS, tsamp=DATA_TSAMP, period=SIGNAL_PERIOD, - dm=dm, amplitude=0.0 - ) - - ### Run pipeline ### - files = glob.glob(f'{outdir}/*.inf') - cmdline_args = ['--config', fname_conf, '--outdir', outdir] + files - parser = get_parser() - args = parser.parse_args(cmdline_args) - run_program(args) - - ### Check output sanity ### - assert not glob.glob(f"{outdir}/*.json") - assert not glob.glob(f"{outdir}/*.png") - - -def load_yaml(fname): - with open(fname, 'r') as fobj: - return yaml.safe_load(fobj) - - -def save_yaml(items, fname): - with open(fname, 'w') as fobj: - return yaml.safe_dump(items, fobj) - - -def test_pipeline_presto_fakepsr(): - # NOTE: outdir is a full path (str) - with tempfile.TemporaryDirectory() as outdir: - fname_conf = 
diff --git a/riptide/tests/test_rseek.py b/riptide/tests/test_rseek.py
deleted file mode 100644
index 1a9f078..0000000
--- a/riptide/tests/test_rseek.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import tempfile
-
-import numpy as np
-from pytest import raises
-from riptide.apps.rseek import get_parser, run_program
-from .presto_generation import generate_data_presto
-
-
-SIGNAL_PERIOD = 1.0
-SIGNAL_FREQ = 1.0 / SIGNAL_PERIOD
-DATA_TOBS = 128.0
-DATA_TSAMP = 256e-6
-
-PARSER = get_parser()
-EXPECTED_COLUMNS = {'period', 'freq', 'width', 'ducy', 'dm', 'snr'}
-DEFAULT_OPTIONS = dict(Pmin=0.5, Pmax=2.0, bmin=480, bmax=520, smin=7.0, format='presto')
-
-
-def dict2args(d):
-    """
-    Convert dictionary of options to command line argument list
-    """
-    args = []
-    for k, v in d.items():
-        args.append(f'--{k}')
-        args.append(str(v))
-    return args
-
-
-def test_rseek_fakepsr():
-    with tempfile.TemporaryDirectory() as outdir:
-        generate_data_presto(
-            outdir, 'data', tobs=DATA_TOBS, tsamp=DATA_TSAMP, period=SIGNAL_PERIOD,
-            dm=0.0, amplitude=20.0, ducy=0.02
-        )
-        fname = os.path.join(outdir, 'data.inf')
-        cmdline_args = dict2args(DEFAULT_OPTIONS) + [fname]
-        args = PARSER.parse_args(cmdline_args)
-        df = run_program(args)
-
-        assert df is not None
-        assert set(df.columns) == EXPECTED_COLUMNS
-
-        # Results must be sorted in decreasing S/N order
-        assert all(df.snr == df.sort_values('snr', ascending=False).snr)
-
-        # Check parameters of the top candidate
-        # NOTE: these checks depend on the RNG seed and the program options
-        topcand = df.iloc[0]
-        assert abs(topcand.freq - SIGNAL_FREQ) < 0.1 / DATA_TOBS
-        assert abs(topcand.snr - 18.5) < 0.15
-        assert topcand.dm == 0
-        assert topcand.width == 13
-
-
-def test_rseek_purenoise():
-    with tempfile.TemporaryDirectory() as outdir:
-        generate_data_presto(
-            outdir, 'data', tobs=DATA_TOBS, tsamp=DATA_TSAMP, period=SIGNAL_PERIOD,
-            dm=0.0, amplitude=0.0
-        )
-        fname = os.path.join(outdir, 'data.inf')
-        cmdline_args = dict2args(DEFAULT_OPTIONS) + [fname]
-        args = PARSER.parse_args(cmdline_args)
-        df = run_program(args)
-
-        assert df is None
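
The tests above drive rseek through get_parser()/run_program() rather than a subprocess; a condensed sketch of that calling convention (the input file is assumed to exist, e.g. produced by generate_data_presto):

```python
# Sketch: run rseek programmatically, as the deleted tests do
from riptide.apps.rseek import get_parser, run_program

args = get_parser().parse_args(
    ["--format", "presto", "--Pmin", "0.5", "--Pmax", "2.0", "data.inf"]
)
df = run_program(args)  # pandas DataFrame, or None if nothing significant
if df is not None:
    print(df.iloc[0].snr)
```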
diff --git a/setup.py b/setup.py
index 517ed8e..d9e7190 100644
--- a/setup.py
+++ b/setup.py
@@ -11,12 +11,12 @@
 # end of these arrays).
 # The flags below provide the same speedups as -ffast-math, without the risks.
 SAFE_FAST_MATH_FLAGS = [
-    '-fassociative-math',
-    '-fno-math-errno',
-    '-ffinite-math-only',
-    '-fno-rounding-math',
-    '-fno-signed-zeros',
-    '-fno-trapping-math',
+    "-fassociative-math",
+    "-fno-math-errno",
+    "-ffinite-math-only",
+    "-fno-rounding-math",
+    "-fno-signed-zeros",
+    "-fno-trapping-math",
 ]
 
 # The main interface is through Pybind11Extension.
@@ -29,9 +29,9 @@
 # reproducible builds (https://github.com/pybind/python_example/pull/53)
 ext_modules = [
     Pybind11Extension(
-        'riptide.libcpp',
-        sorted(['riptide/cpp/python_bindings.cpp']),
-        extra_compile_args=['-O3', '-march=native'] + SAFE_FAST_MATH_FLAGS
+        "riptide.libcpp",
+        sorted(["src/riptide/cpp/python_bindings.cpp"]),
+        extra_compile_args=["-O3", "-march=native"] + SAFE_FAST_MATH_FLAGS,
     ),
 ]
diff --git a/riptide/__init__.py b/src/riptide/__init__.py
similarity index 55%
rename from riptide/__init__.py
rename to src/riptide/__init__.py
index 2e9ef3f..4a5a028 100644
--- a/riptide/__init__.py
+++ b/src/riptide/__init__.py
@@ -12,37 +12,27 @@
 from .search import ffa_search
 from .running_medians import running_median, fast_running_median
 
-from .libffa import (
-    ffa1,
-    ffa2,
-    ffafreq,
-    ffaprd,
-    generate_signal,
-    downsample,
-    boxcar_snr)
+from .libffa import ffa1, ffa2, ffafreq, ffaprd, generate_signal, downsample, boxcar_snr
 
 from .peak_detection import find_peaks
 
 ### Serialization
 from .serialization import save_json, load_json
 
-from .tests import test
-
 __all__ = [
-    'TimeSeries',
-    'Periodogram',
-    'Metadata',
-    'Candidate',
-    'ffa_search',
-    'ffa1',
-    'ffa2',
-    'ffafreq',
-    'ffaprd',
-    'generate_signal',
-    'downsample',
-    'boxcar_snr',
-    'find_peaks',
-    'save_json',
-    'load_json',
-    'test'
-    ]
+    "TimeSeries",
+    "Periodogram",
+    "Metadata",
+    "Candidate",
+    "ffa_search",
+    "ffa1",
+    "ffa2",
+    "ffafreq",
+    "ffaprd",
+    "generate_signal",
+    "downsample",
+    "boxcar_snr",
+    "find_peaks",
+    "save_json",
+    "load_json",
+]
diff --git a/src/riptide/_version.py b/src/riptide/_version.py
new file mode 100644
index 0000000..065144e
--- /dev/null
+++ b/src/riptide/_version.py
@@ -0,0 +1,21 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = "0.2.6.dev5+g054e998.d20250409"
+__version_tuple__ = version_tuple = (0, 2, 6, "dev5", "g054e998.d20250409")
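
The generated module doubles as the runtime source of the version string (rseek's --version flag prints __version__); a quick sketch:

```python
# Sketch: reading the setuptools-scm generated version at runtime
from riptide._version import __version__, version_tuple

print(__version__)    # "0.2.6.dev5+g054e998.d20250409"
print(version_tuple)  # (0, 2, 6, "dev5", "g054e998.d20250409")
```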
diff --git a/riptide/apps/rseek.py b/src/riptide/apps/rseek.py
similarity index 52%
rename from riptide/apps/rseek.py
rename to src/riptide/apps/rseek.py
index 91ce8b8..c211242 100644
--- a/riptide/apps/rseek.py
+++ b/src/riptide/apps/rseek.py
@@ -8,83 +8,96 @@
 from riptide.clustering import cluster1d
 
 
-log = logging.getLogger('riptide.rseek')
-help_formatter = lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=16)
+log = logging.getLogger("riptide.rseek")
+help_formatter = lambda prog: argparse.ArgumentDefaultsHelpFormatter(
+    prog, max_help_position=16
+)
+
 
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=help_formatter,
         description=(
-            "FFA search a single time series and print a table of parameters of all significant peaks found."
-            " Peaks found with nearly identical periods at different trial pulse widths are grouped,"
-            " but no harmonic filtering is performed."
-        )
+            "FFA search a single time series and print a table of parameters of all significant peaks found."
+            " Peaks found with nearly identical periods at different trial pulse widths are grouped,"
+            " but no harmonic filtering is performed."
+        ),
     )
     parser.add_argument(
-        "-f", "--format", type=str, choices=('presto', 'sigproc'), required=True,
-        help="Input TimeSeries format"
+        "-f",
+        "--format",
+        type=str,
+        choices=("presto", "sigproc"),
+        required=True,
+        help="Input TimeSeries format",
     )
     parser.add_argument(
-        "--Pmin", type=float, default=1.0,
-        help="Minimum trial period in seconds"
+        "--Pmin", type=float, default=1.0, help="Minimum trial period in seconds"
     )
     parser.add_argument(
-        "--Pmax", type=float, default=10.0,
-        help="Maximum trial period in seconds"
+        "--Pmax", type=float, default=10.0, help="Maximum trial period in seconds"
    )
     parser.add_argument(
-        "--bmin", type=int, default=240,
-        help="Minimum number of phase bins used in the search"
+        "--bmin",
+        type=int,
+        default=240,
+        help="Minimum number of phase bins used in the search",
     )
     parser.add_argument(
-        "--bmax", type=int, default=260,
-        help="Maximum number of phase bins used in the search"
+        "--bmax",
+        type=int,
+        default=260,
+        help="Maximum number of phase bins used in the search",
     )
     parser.add_argument(
-        "--smin", type=float, default=7.0,
-        help="Only report peaks above this minimum S/N"
+        "--smin",
+        type=float,
+        default=7.0,
+        help="Only report peaks above this minimum S/N",
     )
     parser.add_argument(
-        "--wtsp", type=float, default=1.5,
-        help="Geometric factor between consecutive trial pulse widths"
+        "--wtsp",
+        type=float,
+        default=1.5,
+        help="Geometric factor between consecutive trial pulse widths",
     )
     parser.add_argument(
-        "--rmed_width", type=float, default=4.0,
-        help="Width (in seconds) of the running median filter to subtract from the input data before processing"
+        "--rmed_width",
+        type=float,
+        default=4.0,
+        help="Width (in seconds) of the running median filter to subtract from the input data before processing",
     )
window" - " Lower values make the running median calculation less accurate but" - " faster, due to allowing a higher scrunching factor" - ) + "The running median is calculated of a time scrunched version of the" + " input data to save time: rmed_minpts is the minimum number of" + " scrunched samples that must fit in the running median window" + " Lower values make the running median calculation less accurate but" + " faster, due to allowing a higher scrunching factor" + ), ) parser.add_argument( - "--clrad", type=float, default=0.2, + "--clrad", + type=float, + default=0.2, help=( "Frequency clustering radius in units of 1/Tobs. Peaks with similar" " freqs are grouped together, and only the brightest one of the group" " is printed" - ) - ) - parser.add_argument( - "fname", type=str, - help="Input file name" + ), ) - parser.add_argument( - '--version', action='version', version=__version__ - ) + parser.add_argument("fname", type=str, help="Input file name") + parser.add_argument("--version", action="version", version=__version__) return parser def run_program(args): """ - Run the rseek program and return a pandas DataFrame with the detected peak - parameters, or None if no significant peaks were found. This is used to + Run the rseek program and return a pandas DataFrame with the detected peak + parameters, or None if no significant peaks were found. This is used to check the results in unit tests. Parameters @@ -98,20 +111,19 @@ def run_program(args): DataFrame with columns: 'period', 'freq', 'width', 'ducy', 'dm', 'snr' """ logging.basicConfig( - level='DEBUG', - format='%(asctime)s %(filename)18s:%(lineno)-4s %(levelname)-8s %(message)s' + level="DEBUG", + format="%(asctime)s %(filename)18s:%(lineno)-4s %(levelname)-8s %(message)s", ) - LOADERS = { - 'sigproc': TimeSeries.from_sigproc, - 'presto' : TimeSeries.from_presto_inf - } + LOADERS = {"sigproc": TimeSeries.from_sigproc, "presto": TimeSeries.from_presto_inf} loader = LOADERS[args.format] # Search and find peaks ts = loader(args.fname) - log.debug(f"Searching period range [{args.Pmin}, {args.Pmax}] seconds with {args.bmin} to {args.bmax} phase bins") + log.debug( + f"Searching period range [{args.Pmin}, {args.Pmax}] seconds with {args.bmin} to {args.bmax} phase bins" + ) __, pgram = ffa_search( ts, period_min=args.Pmin, @@ -121,8 +133,8 @@ def run_program(args): rmed_width=args.rmed_width, rmed_minpts=args.rmed_minpts, wtsp=args.wtsp, - fpmin=1, # No dynamic cap on period_max - ducy_max=0.3 + fpmin=1, # No dynamic cap on period_max + ducy_max=0.3, ) peaks, __ = find_peaks(pgram, smin=args.smin, clrad=args.clrad) @@ -133,7 +145,7 @@ def run_program(args): # Cluster peaks, i.e. 
diff --git a/riptide/candidate.py b/src/riptide/candidate.py
similarity index 68%
rename from riptide/candidate.py
rename to src/riptide/candidate.py
index 56fbe0e..c8f34b3 100644
--- a/riptide/candidate.py
+++ b/src/riptide/candidate.py
@@ -8,7 +8,7 @@
 from astropy.time import Time
 
 
-log = logging.getLogger('riptide.candidate')
+log = logging.getLogger("riptide.candidate")
 
 
 class Candidate(object):
@@ -42,6 +42,7 @@
         Tuple of numpy arrays (dm, snr) containing respectively the sequence
         of DM trials, and corresponding best S/N value across all trial widths
     """
+
     def __init__(self, params, tsmeta, peaks, subints):
         self.params = params
         self.tsmeta = tsmeta
@@ -49,12 +50,12 @@
         self.subints = subints
 
     def to_dict(self):
-        """ Convert to dictionary for serialization """
+        """Convert to dictionary for serialization"""
         return {
-            'params': self.params,
-            'tsmeta': self.tsmeta,
-            'peaks': self.peaks,
-            'subints': self.subints
+            "params": self.params,
+            "tsmeta": self.tsmeta,
+            "peaks": self.peaks,
+            "subints": self.subints,
         }
 
     @property
@@ -68,15 +69,15 @@
         # NOTE: copy() works around a bug in pandas 0.23.x and earlier
         # https://stackoverflow.com/questions/53985535/pandas-valueerror-buffer-source-array-is-read-only
         # TODO: consider requiring pandas 0.24+ in the future
-        df = self.peaks.copy().groupby('dm').max()
+        df = self.peaks.copy().groupby("dm").max()
         return df.index.values, df.snr.values
 
     @classmethod
     def from_pipeline_output(cls, ts, peak_cluster, bins, subints=1):
         """
         Method used by the pipeline to produce a candidate from intermediate
-        data products.
-
+        data products.
+
         subints can be an int or None. None means pick the number of subints
         that fit inside the data.
@@ -95,12 +96,17 @@ def from_pipeline_output(cls, ts, peak_cluster, bins, subints=1):
             subints = None
 
         subints_array = ts.fold(centre.period, bins, subints=subints)
-        return cls(centre.summary_dict(), ts.metadata, peak_cluster.summary_dataframe(), subints_array)
+        return cls(
+            centre.summary_dict(),
+            ts.metadata,
+            peak_cluster.summary_dataframe(),
+            subints_array,
+        )
 
     @classmethod
     def from_dict(cls, items):
-        """ De-serialize from dictionary """
-        return cls(items['params'], items['tsmeta'], items['peaks'], items['subints'])
+        """De-serialize from dictionary"""
+        return cls(items["params"], items["tsmeta"], items["peaks"], items["subints"])
 
     def plot(self, figsize=(18, 4.5), dpi=80):
         """
@@ -131,7 +137,7 @@
 
     def savefig(self, fname, **kwargs):
         """
-        Create a plot of the candidate and save it as PNG under the specified
+        Create a plot of the candidate and save it as PNG under the specified
         file name. Accepts the same keyword arguments as plot().
         """
         fig = self.plot(**kwargs)
@@ -146,7 +152,7 @@
         return str(self)
 
 
-TableEntryBase = namedtuple('TableEntry', ['name', 'value', 'formatter', 'unit'])
+TableEntryBase = namedtuple("TableEntry", ["name", "value", "formatter", "unit"])
 
 
 class TableEntry(TableEntryBase):
@@ -157,61 +163,64 @@
         y : float
             Y coordinate of the line
         """
-        assert(len(X) == 3)
+        assert len(X) == 3
         fmt = "{{:{}}}".format(self.formatter)
         plt.text(X[0], y, self.name, **kwargs)
-        plt.text(X[1], y, fmt.format(self.value), ha='right', **kwargs)
+        plt.text(X[1], y, fmt.format(self.value), ha="right", **kwargs)
         plt.text(X[2], y, self.unit, **kwargs)
 
 
 def plot_table(params, tsmeta):
-    """
-    """
-    plt.axis('off')
-    coord = tsmeta['skycoord']
-    ra_hms = coord.ra.to_string(unit=uu.hour, sep=':', precision=2, pad=True)
-    dec_hms = coord.dec.to_string(unit=uu.deg, sep=':', precision=2, pad=True)
+    """ """
+    plt.axis("off")
+    coord = tsmeta["skycoord"]
+    ra_hms = coord.ra.to_string(unit=uu.hour, sep=":", precision=2, pad=True)
+    dec_hms = coord.dec.to_string(unit=uu.deg, sep=":", precision=2, pad=True)
 
     # TODO: Check that the scale is actually UTC in the general case
     # PRESTO, SIGPROC and other packages may not have the same
     # date/time standard
-    obsdate = Time(tsmeta['mjd'], format='mjd', scale='utc', precision=0)
+    obsdate = Time(tsmeta["mjd"], format="mjd", scale="utc", precision=0)
 
-    blank = TableEntry(name='', value='', formatter='s', unit='')
+    blank = TableEntry(name="", value="", formatter="s", unit="")
     entries = [
-        TableEntry(name='Period', value=params['period'] * 1000.0, formatter='.6f', unit='ms'),
-        TableEntry(name='DM', value=params['dm'], formatter='.2f', unit='pc cm$^{-3}$'),
-        TableEntry(name='Width', value=params['width'], formatter='d', unit='bins'),
-        TableEntry(name='Duty cycle', value=params['ducy'] * 100.0, formatter='.2f', unit='%'),
-        TableEntry(name='S/N', value=params['snr'], formatter='.1f', unit=''),
+        TableEntry(
+            name="Period", value=params["period"] * 1000.0, formatter=".6f", unit="ms"
+        ),
+        TableEntry(name="DM", value=params["dm"], formatter=".2f", unit="pc cm$^{-3}$"),
+        TableEntry(name="Width", value=params["width"], formatter="d", unit="bins"),
+        TableEntry(
+            name="Duty cycle", value=params["ducy"] * 100.0, formatter=".2f", unit="%"
+        ),
+        TableEntry(name="S/N", value=params["snr"], formatter=".1f", unit=""),
         blank,
-        TableEntry(name='Source', value=tsmeta['source_name'], formatter='s', unit=''),
-        TableEntry(name='RA', value=ra_hms, formatter='s', unit=''),
-        TableEntry(name='Dec', value=dec_hms, formatter='s', unit=''),
-        TableEntry(name='MJD', value=obsdate.mjd, formatter='.6f', unit=''),
-        TableEntry(name='UTC', value=obsdate.iso, formatter='s', unit=''),
+        TableEntry(name="Source", value=tsmeta["source_name"], formatter="s", unit=""),
+        TableEntry(name="RA", value=ra_hms, formatter="s", unit=""),
+        TableEntry(name="Dec", value=dec_hms, formatter="s", unit=""),
+        TableEntry(name="MJD", value=obsdate.mjd, formatter=".6f", unit=""),
+        TableEntry(name="UTC", value=obsdate.iso, formatter="s", unit=""),
     ]
 
-    y0 = 0.94 # Y coordinate of first line
-    dy = 0.105 # line height
-    X = [0.0, 0.80, 0.84] # Coordinate of columns name, value, unit
+    y0 = 0.94  # Y coordinate of first line
+    dy = 0.105  # line height
+    X = [0.0, 0.80, 0.84]  # Coordinate of columns name, value, unit
 
     for ii, entry in enumerate(entries):
-        entry.plot(X, y0 - ii * dy, family='monospace')
+        entry.plot(X, y0 - ii * dy, family="monospace")
 
 
 def plot_dm_curve(dm, snr):
     dm_min = dm.min()
     dm_max = dm.max()
-    plt.plot(dm, snr, color='r', marker='o', markersize=3)
+    plt.plot(dm, snr, color="r", marker="o", markersize=3)
 
-    # Avoid matplotlib warning when calling xlim() with two equal values
+    # Avoid matplotlib warning when calling xlim() with two equal values
     if dm_min == dm_max:
         plt.xlim(dm_min - 0.5, dm_min + 0.5)
     else:
         plt.xlim(dm_min, dm_max)
 
-    plt.grid(linestyle=':')
+    plt.grid(linestyle=":")
     plt.xlabel("DM (pc cm$^{-3}$)")
     plt.ylabel("Best S/N")
 
@@ -225,16 +234,18 @@
     """
     __, nbins = X.shape
-    X = np.hstack((X, X[:, :nbins//2]))
+    X = np.hstack((X, X[:, : nbins // 2]))
     __, nbins_ext = X.shape
 
     plt.imshow(
-        X,
-        cmap='Greys', interpolation='nearest', aspect='auto',
-        extent=[-0.5, nbins_ext-0.5, T, 0] # Note: t = 0 is at the top of the plot
-    )
-    plt.fill_between([nbins, nbins_ext], [0, 0], [T, T], color='b', alpha=0.08)
-    plt.xlim(-0.5, nbins_ext-0.5)
+        X,
+        cmap="Greys",
+        interpolation="nearest",
+        aspect="auto",
+        extent=[-0.5, nbins_ext - 0.5, T, 0],  # Note: t = 0 is at the top of the plot
+    )
+    plt.fill_between([nbins, nbins_ext], [0, 0], [T, T], color="b", alpha=0.08)
+    plt.xlim(-0.5, nbins_ext - 0.5)
     plt.ylabel("Time (seconds)")
     plt.title("1.5 Periods of Signal")
 
@@ -244,16 +255,18 @@
     P : profile normalised to unit background noise variance
     """
     nbins = len(P)
-    P = np.concatenate((P, P[:nbins//2]))
+    P = np.concatenate((P, P[: nbins // 2]))
     nbins_ext = len(P)
 
-    plt.bar(range(nbins_ext), P - np.median(P), width=1, color='#404040')
+    plt.bar(range(nbins_ext), P - np.median(P), width=1, color="#404040")
 
     ymin, ymax = plt.ylim()
-    plt.fill_between([nbins, nbins_ext], [ymin, ymin], [ymax, ymax], color='b', alpha=0.08)
+    plt.fill_between(
+        [nbins, nbins_ext], [ymin, ymin], [ymax, ymax], color="b", alpha=0.08
+    )
     plt.ylim(ymin, ymax)
 
-    plt.xlim(-0.5, nbins_ext-0.5)
+    plt.xlim(-0.5, nbins_ext - 0.5)
     plt.xlabel("Phase bin")
     plt.ylabel("Normalised amplitude")
 
@@ -267,7 +280,7 @@
     gs = GridSpec(nrows, ncols, figure=plt.gcf())
 
     plt.subplot(gs[:1, 2:])
-    plot_subints(cand.subints, cand.tsmeta['tobs'])
+    plot_subints(cand.subints, cand.tsmeta["tobs"])
 
     plt.subplot(gs[1:, 2:])
     plot_profile(cand.profile)
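
A sketch of the consumer side of this class: pipeline candidates are JSON files that round-trip through load_json, and the plotting entry points are the methods diffed above (the file names here are hypothetical):

```python
# Sketch: load a saved candidate and plot it
from riptide import load_json

cand = load_json("candidate_0000.json")  # hypothetical pipeline output
print(cand.params["period"], cand.params["dm"], cand.params["snr"])
cand.savefig("candidate_0000.png")
```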
diff --git a/riptide/clustering.py b/src/riptide/clustering.py
similarity index 84%
rename from riptide/clustering.py
rename to src/riptide/clustering.py
index c48df00..c96aa95 100644
--- a/riptide/clustering.py
+++ b/src/riptide/clustering.py
@@ -34,7 +34,7 @@
 
     # NOTE: diff is the sequence of consecutive differences
     # of x AFTER it has been sorted
-    # Indices
+    # Indices
     ibreaks = np.where(abs(diff) > r)[0]
 
     # In this case, there is only one cluster
@@ -42,10 +42,7 @@
         return [indices]
 
     # Cluster bounds indices
-    ibounds = np.concatenate(([0], ibreaks+1, [len(x)]))
-
-    clusters = [
-        indices[start:end]
-        for start, end in zip(ibounds[:-1], ibounds[1:])
-    ]
-    return clusters
\ No newline at end of file
+    ibounds = np.concatenate(([0], ibreaks + 1, [len(x)]))
+
+    clusters = [indices[start:end] for start, end in zip(ibounds[:-1], ibounds[1:])]
+    return clusters
diff --git a/riptide/cpp/README.md b/src/riptide/cpp/README.md
similarity index 100%
rename from riptide/cpp/README.md
rename to src/riptide/cpp/README.md
diff --git a/riptide/cpp/block.hpp b/src/riptide/cpp/block.hpp
similarity index 100%
rename from riptide/cpp/block.hpp
rename to src/riptide/cpp/block.hpp
diff --git a/riptide/cpp/downsample.hpp b/src/riptide/cpp/downsample.hpp
similarity index 100%
rename from riptide/cpp/downsample.hpp
rename to src/riptide/cpp/downsample.hpp
diff --git a/riptide/cpp/kernels.hpp b/src/riptide/cpp/kernels.hpp
similarity index 100%
rename from riptide/cpp/kernels.hpp
rename to src/riptide/cpp/kernels.hpp
diff --git a/riptide/cpp/periodogram.hpp b/src/riptide/cpp/periodogram.hpp
similarity index 100%
rename from riptide/cpp/periodogram.hpp
rename to src/riptide/cpp/periodogram.hpp
diff --git a/riptide/cpp/python_bindings.cpp b/src/riptide/cpp/python_bindings.cpp
similarity index 100%
rename from riptide/cpp/python_bindings.cpp
rename to src/riptide/cpp/python_bindings.cpp
diff --git a/riptide/cpp/running_median.hpp b/src/riptide/cpp/running_median.hpp
similarity index 100%
rename from riptide/cpp/running_median.hpp
rename to src/riptide/cpp/running_median.hpp
diff --git a/riptide/cpp/snr.hpp b/src/riptide/cpp/snr.hpp
similarity index 100%
rename from riptide/cpp/snr.hpp
rename to src/riptide/cpp/snr.hpp
diff --git a/riptide/cpp/transforms.hpp b/src/riptide/cpp/transforms.hpp
similarity index 100%
rename from riptide/cpp/transforms.hpp
rename to src/riptide/cpp/transforms.hpp
diff --git a/riptide/ffautils.py b/src/riptide/ffautils.py
similarity index 99%
rename from riptide/ffautils.py
rename to src/riptide/ffautils.py
index 9e8944a..0fef1e1 100644
--- a/riptide/ffautils.py
+++ b/src/riptide/ffautils.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+
 def generate_width_trials(nbins, ducy_max=0.20, wtsp=1.5):
     widths = []
     w = 1
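
generate_width_trials builds the geometric ladder of boxcar widths referenced by the wtsp options elsewhere in this patch (consecutive trial widths separated by a factor wtsp, capped by ducy_max); a usage sketch, without asserting the exact values returned:

```python
# Sketch: trial widths for a 250-bin search, using the defaults shown above
from riptide.ffautils import generate_width_trials

widths = generate_width_trials(250, ducy_max=0.20, wtsp=1.5)
print(widths)  # grows roughly geometrically up to ducy_max * nbins
```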
diff --git a/riptide/folding.py b/src/riptide/folding.py
similarity index 88%
rename from riptide/folding.py
rename to src/riptide/folding.py
index 78c8f08..9999c16 100644
--- a/riptide/folding.py
+++ b/src/riptide/folding.py
@@ -9,7 +9,9 @@
     if not factor > 1:
         raise ValueError("factor must be > 1")
     if not factor < m:
-        raise ValueError("factor must be strictly smaller than the number of input lines")
+        raise ValueError(
+            "factor must be strictly smaller than the number of input lines"
+        )
 
     Y = np.ascontiguousarray(X.T)
     out = np.asarray([downsample(arr, factor) for arr in Y])
@@ -29,8 +31,8 @@
     bins : int
         Number of phase bins
     subints : int or None, optional
-        Number of desired sub-integrations. If None, the number of
-        sub-integrations will be the number of full periods that fit in
+        Number of desired sub-integrations. If None, the number of
+        sub-integrations will be the number of full periods that fit in
         the data (default: None)
 
     Returns
@@ -38,7 +40,7 @@
     folded : ndarray
         The folded data as a numpy array. If subints > 1, it has a shape
         (subints, bins). Otherwise it is a 1D array with 'bins' elements.
-
+
     Raises
     ------
     ValueError: if the data cannot be folded with the requested parameters,
@@ -59,13 +61,15 @@
     full_periods = ts.length / period
 
     if subints > full_periods:
-        raise ValueError(f"subints ({subints}) exceeds the number of signal periods that fit in the data ({full_periods})")
+        raise ValueError(
+            f"subints ({subints}) exceeds the number of signal periods that fit in the data ({full_periods})"
+        )
 
     factor = tbin / ts.tsamp
     tsdown = ts.downsample(factor)
 
     m = tsdown.nsamp // bins
     nsamp_eff = m * bins
-
+
     folded = tsdown.data[:nsamp_eff].reshape(m, bins)
     folded *= (m * factor) ** -0.5
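
A sketch of the public entry point for this module, TimeSeries.fold(), which the candidate builder calls with an explicit subints count (pass subints=1 for a 1D profile, per the docstring above):

```python
# Sketch: fold a synthetic time series into subintegrations and a profile
from riptide import TimeSeries

ts = TimeSeries.generate(128.0, 256e-6, 1.0, amplitude=20.0)
subints = ts.fold(1.0, 128, subints=16)  # 2D array, shape (16, 128)
profile = ts.fold(1.0, 128, subints=1)   # 1D profile with 128 bins
print(subints.shape, profile.shape)
```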
""" # von mises parameter - kappa = log(2.0) / (2.0 * sin(pi*ducy/2.0)**2) + kappa = log(2.0) / (2.0 * sin(pi * ducy / 2.0) ** 2) # Generate pulse train phase_radians = (np.arange(nsamp, dtype=float) / period - phi0) * (2 * pi) - signal = exp(kappa*(cos(phase_radians) - 1.0)) + signal = exp(kappa * (cos(phase_radians) - 1.0)) # Normalise to unit L2-norm, then scale by amplitude - scale_factor = amplitude * (signal ** 2).sum() ** -0.5 + scale_factor = amplitude * (signal**2).sum() ** -0.5 signal *= scale_factor # Add noise @@ -69,7 +69,7 @@ def generate_signal(nsamp, period, phi0=0.5, ducy=0.02, amplitude=10.0, stdnoise def ffa2(data): - """ + """ Compute the FFA transform of a two-dimensional input Parameters @@ -92,7 +92,7 @@ def ffa2(data): def ffa1(data, p): - """ + """ Compute the FFA transform of a one-dimensional input (time series) at base period p @@ -108,7 +108,7 @@ def ffa1(data, p): Returns ------- transform : ndarray (2D) - The FFA transform of 'data', as a float32 2D array of shape (m, p), + The FFA transform of 'data', as a float32 2D array of shape (m, p), where m is the number of complete pulse periods in the data See Also @@ -123,7 +123,7 @@ def ffa1(data, p): if p > data.size: raise ValueError("p must be smaller than the total number of samples") m = data.size // p - return ffa2(data[:m*p].reshape(m, p)) + return ffa2(data[: m * p].reshape(m, p)) def ffafreq(N, p, dt=1.0): @@ -164,7 +164,7 @@ def ffafreq(N, p, dt=1.0): f = np.asarray([f0]) else: s = np.arange(m) - f = (f0 - s / (m-1.0) * f0**2) + f = f0 - s / (m - 1.0) * f0**2 f /= dt return f @@ -192,8 +192,8 @@ def ffaprd(N, p, dt=1.0): def boxcar_snr(data, widths, stdnoise=1.0): - """ - Compute the S/N ratio of pulse profile(s) for a range of + """ + Compute the S/N ratio of pulse profile(s) for a range of boxcar width trials. Parameters @@ -226,7 +226,7 @@ def boxcar_snr(data, widths, stdnoise=1.0): def downsample(data, factor): - """ Downsample an array by a real-valued factor. + """Downsample an array by a real-valued factor. Parameters ---------- diff --git a/riptide/metadata.py b/src/riptide/metadata.py similarity index 64% rename from riptide/metadata.py rename to src/riptide/metadata.py index 80353eb..f6172c0 100644 --- a/riptide/metadata.py +++ b/src/riptide/metadata.py @@ -9,22 +9,21 @@ SCHEMA_ITEMS = { - Optional('source_name') : Or(str, None), - Optional('skycoord') : Or(SkyCoord, None), - Optional('dm'): Or(And(float, lambda x: x >= 0), None), - Optional('mjd'): Or(And(float, lambda x: x >= 0), None), - Optional('tobs'): Or(And(float, lambda x: x > 0), None), - Optional('fname'): Or(str, None), - + Optional("source_name"): Or(str, None), + Optional("skycoord"): Or(SkyCoord, None), + Optional("dm"): Or(And(float, lambda x: x >= 0), None), + Optional("mjd"): Or(And(float, lambda x: x >= 0), None), + Optional("tobs"): Or(And(float, lambda x: x > 0), None), + Optional("fname"): Or(str, None), # Accept any extra keys of type string with JSON-serializable values - Optional(str): json.dumps - } + Optional(str): json.dumps, +} SCHEMA = Schema(SCHEMA_ITEMS, ignore_extra_keys=True) class Metadata(dict): - """ + """ A dict subclass that carries information about an observation across all data products (TimeSeries, Periodogram, etc.) @@ -42,17 +41,18 @@ class Metadata(dict): If any of the above keys are NOT present, they will be set to None in the Metadata object. 
""" + def __init__(self, items={}): SCHEMA.validate(items) super(Metadata, self).__init__(items) - + for k in SCHEMA_ITEMS: if isinstance(k.schema, str): self.setdefault(k.schema, None) @classmethod def from_presto_inf(cls, inf): - """ + """ Create Metadata object from PRESTO .inf file or PrestoInf object Parameters @@ -65,14 +65,14 @@ def from_presto_inf(cls, inf): inf = PrestoInf(inf) attrs = dict(inf) - attrs['skycoord'] = inf.skycoord - attrs['fname'] = os.path.realpath(inf.fname) - attrs['tobs'] = attrs['tsamp'] * attrs['nsamp'] + attrs["skycoord"] = inf.skycoord + attrs["fname"] = os.path.realpath(inf.fname) + attrs["tobs"] = attrs["tsamp"] * attrs["nsamp"] return cls(attrs) @classmethod def from_sigproc(cls, sh, extra_keys={}): - """ + """ Create Metadata object from SIGPROC dedispersed time series file, or SigprocHeader object. @@ -85,24 +85,30 @@ def from_sigproc(cls, sh, extra_keys={}): if type(sh) == str: sh = SigprocHeader(sh, extra_keys=extra_keys) - if sh['nchans'] > 1: - raise ValueError(f"File {sh.fname!r} contains multi-channel data (nchans = {sh['nchans']}), instead of a dedispersed time series") + if sh["nchans"] > 1: + raise ValueError( + f"File {sh.fname!r} contains multi-channel data (nchans = {sh['nchans']}), instead of a dedispersed time series" + ) # Make sure this is a 32-bit dedispersed time series # We support either: 32-bit float data, or 8-bit data but only if signedness is specified in the header - nbits = sh['nbits'] + nbits = sh["nbits"] if not nbits in {8, 32}: - raise ValueError(f"Only 8-bit and 32-bit SIGPROC data are supported. File {sh.fname!r} contains {nbits}-bit data") - if nbits == 8 and 'signed' not in sh: - raise ValueError(f"SIGPROC Header says this is 8-bit data, but does not specify its signedness via the 'signed' key") + raise ValueError( + f"Only 8-bit and 32-bit SIGPROC data are supported. 
File {sh.fname!r} contains {nbits}-bit data" + ) + if nbits == 8 and "signed" not in sh: + raise ValueError( + f"SIGPROC Header says this is 8-bit data, but does not specify its signedness via the 'signed' key" + ) attrs = dict(sh).copy() - attrs['dm'] = attrs.get('refdm', None) - attrs['skycoord'] = sh.skycoord - attrs['source_name'] = attrs.get('source_name', None) - attrs['mjd'] = attrs.get('tstart', None) - attrs['fname'] = os.path.realpath(sh.fname) - attrs['tobs'] = sh.tobs + attrs["dm"] = attrs.get("refdm", None) + attrs["skycoord"] = sh.skycoord + attrs["source_name"] = attrs.get("source_name", None) + attrs["mjd"] = attrs.get("tstart", None) + attrs["fname"] = os.path.realpath(sh.fname) + attrs["tobs"] = sh.tobs return cls(attrs) def to_dict(self): @@ -113,7 +119,7 @@ def from_dict(cls, items): return cls(items) def __str__(self): - return 'Metadata %s' % pprint.pformat(dict(self)) + return "Metadata %s" % pprint.pformat(dict(self)) def __repr__(self): return str(self) diff --git a/riptide/peak_detection.py b/src/riptide/peak_detection.py similarity index 80% rename from riptide/peak_detection.py rename to src/riptide/peak_detection.py index 1949504..a94a6cc 100644 --- a/riptide/peak_detection.py +++ b/src/riptide/peak_detection.py @@ -8,38 +8,39 @@ from riptide.timing import timing -log = logging.getLogger('riptide.peak_detection') +log = logging.getLogger("riptide.peak_detection") class Peak(typing.NamedTuple): - """ - A simple NamedTuple with the essential parameters of a peak found + """ + A simple NamedTuple with the essential parameters of a peak found in a Periodogram """ + period: float freq: float width: int ducy: float # duty cycle - iw: int # width trial index - ip: int # period trial index + iw: int # width trial index + ip: int # period trial index snr: float dm: float def summary_dict(self): - """ + """ Returns a minimal dictionary of attributes to be written as CSV - by the pipeline + by the pipeline """ - attrs = ('period', 'freq', 'dm', 'width', 'ducy', 'snr') + attrs = ("period", "freq", "dm", "width", "ducy", "snr") return {a: getattr(self, a) for a in attrs} def segment_stats(f, s, T, segwidth=5.0): """ - Cut a periodogram in consecutive, equal-sized segments with a + Cut a periodogram in consecutive, equal-sized segments with a frequency span equal to segwidth / T, and return the centre frequencies, median S/N and robust S/N standard deviation of all segments. - + This information is then used to fit a sensible peak selection threshold as a function of frequency. 
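
The per-segment statistics described above are straightforward to reproduce with
plain NumPy. The following sketch is illustrative only (it is not part of this
patch, and `segment_stats_sketch` is a hypothetical name); it assumes `f` and `s`
are 1D arrays of trial frequencies and S/N values from an observation of length
T seconds, and it uses the IQR-to-stddev conversion factor of 1.349 quoted in
the docstring:

    import numpy as np

    def segment_stats_sketch(f, s, T, segwidth=5.0):
        # Split the trials into m segments, each spanning ~segwidth / T in
        # frequency, keeping only the n = m * p samples that fill them exactly
        w = segwidth / T
        m = int(np.ceil(abs(f[-1] - f[0]) / w))
        p = len(f) // m
        n = m * p
        fseg = f[:n].reshape(m, p)
        sseg = s[:n].reshape(m, p)
        fc = fseg.mean(axis=1)                    # segment centre frequencies
        med = np.median(sseg, axis=1)             # per-segment median S/N
        q75, q25 = np.percentile(sseg, [75, 25], axis=1)
        std = (q75 - q25) / 1.349                 # robust stddev from the IQR
        return fc, med, std
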
@@ -65,16 +66,16 @@ def segment_stats(f, s, T, segwidth=5.0): range of the segment's S/N distribution (stddev = IQR / 1.349) """ w = segwidth / T - #log.debug("Segment width (Hz): {:.6f}".format(w)) + # log.debug("Segment width (Hz): {:.6f}".format(w)) # NOTE: the spacing of frequency trials is almost constant - m = ceil(abs(f[-1] - f[0]) / w) # number of segments - #log.debug("Segments: {:d}".format(m)) + m = ceil(abs(f[-1] - f[0]) / w) # number of segments + # log.debug("Segments: {:d}".format(m)) - p = len(f) // m # number of complete segments - #log.debug("Points/segment: {:d}".format(p)) + p = len(f) // m # number of complete segments + # log.debug("Points/segment: {:d}".format(p)) - n = m * p # effective number of elements + n = m * p # effective number of elements f = f[:n] s = s[:n] @@ -108,7 +109,9 @@ def fit_threshold(fc, tc, polydeg=2): return np.poly1d(coeffs) -def find_peaks_single(f, s, T, smin=6.0, segwidth=5.0, nstd=7.0, minseg=10, polydeg=2, clrad=0.1): +def find_peaks_single( + f, s, T, smin=6.0, segwidth=5.0, nstd=7.0, minseg=10, polydeg=2, clrad=0.1 +): """ Find peaks in a single pulse width trial. Returns a list of array indices that correspond to peak centres @@ -123,7 +126,7 @@ def find_peaks_single(f, s, T, smin=6.0, segwidth=5.0, nstd=7.0, minseg=10, poly if len(fc) >= minseg: poly = fit_threshold(fc, sc, polydeg=polydeg) polyco = poly.coefficients - else: # constant threshold if not enough points for fit + else: # constant threshold if not enough points for fit polyco = [smin] poly = np.poly1d(polyco) @@ -143,7 +146,9 @@ def find_peaks_single(f, s, T, smin=6.0, segwidth=5.0, nstd=7.0, minseg=10, poly @timing -def find_peaks(pgram, smin=6.0, segwidth=5.0, nstd=6.0, minseg=10, polydeg=2, clrad=0.1): +def find_peaks( + pgram, smin=6.0, segwidth=5.0, nstd=6.0, minseg=10, polydeg=2, clrad=0.1 +): """ Identify significant peaks in a periodogram using a dynamically fitted S/N selection threshold. The fitting involves the following procedure for @@ -151,23 +156,23 @@ def find_peaks(pgram, smin=6.0, segwidth=5.0, nstd=6.0, minseg=10, polydeg=2, cl 1. Cut the frequency range covered by the periodogram in segments of length 1 / T_obs - 2. Get the median S/N 'm' and robust S/N standard deviation 's' of each + 2. Get the median S/N 'm' and robust S/N standard deviation 's' of each segment. The dynamic selection threshold for that segment should be t = m + nstd x s - 3. Fit a polynomial in log(f) to the control points (f_i, t_i) thus + 3. Fit a polynomial in log(f) to the control points (f_i, t_i) thus obtained - 4. Any point whose S/N exceeds both the dynamic threshold and the value + 4. Any point whose S/N exceeds both the dynamic threshold and the value 'smin' are considered significant 5. Cluster these points. Two points are in the same peak if their trial frequencies are within clrad / T_obs of each other. All such clusters constitute a Peak. 
- + Parameters ---------- pgram : Periodogram Input periodogram to search for peaks smin : float, optional - Minimum S/N that a peak must exceed, in addition to having to exceed + Minimum S/N that a peak must exceed, in addition to having to exceed the dynamic selection threshold segwidth : float, optional Width of a frequency segment in units of 1 / T_obs @@ -192,15 +197,22 @@ def find_peaks(pgram, smin=6.0, segwidth=5.0, nstd=6.0, minseg=10, polydeg=2, cl """ f = pgram.freqs T = pgram.tobs - dm = pgram.metadata['dm'] + dm = pgram.metadata["dm"] peaks = [] polycos = {} for iw, width in enumerate(pgram.widths): s = pgram.snrs[:, iw].astype(float) cur_peak_indices, cur_polycos = find_peaks_single( - f, s, T, - smin=smin, segwidth=segwidth, nstd=nstd, minseg=minseg, polydeg=polydeg, clrad=clrad + f, + s, + T, + smin=smin, + segwidth=segwidth, + nstd=nstd, + minseg=minseg, + polydeg=polydeg, + clrad=clrad, ) for ipeak in cur_peak_indices: peak_freq = f[ipeak] @@ -211,12 +223,18 @@ def find_peaks(pgram, smin=6.0, segwidth=5.0, nstd=6.0, minseg=10, polydeg=2, cl # have np.float32 type which causes trouble down the line # NOTE 2: dm can be None on fake time series peak = Peak( - freq=float(peak_freq), period=float(1.0/peak_freq), width=int(width), - ducy=float(peak_ducy), iw=int(iw), ip=int(ipeak), snr=float(s[ipeak]), - dm=dm) - #log.debug(peak) + freq=float(peak_freq), + period=float(1.0 / peak_freq), + width=int(width), + ducy=float(peak_ducy), + iw=int(iw), + ip=int(ipeak), + snr=float(s[ipeak]), + dm=dm, + ) + # log.debug(peak) peaks.append(peak) polycos[iw] = cur_polycos - + peaks = sorted(peaks, key=lambda p: p.snr, reverse=True) return peaks, polycos diff --git a/riptide/periodogram.py b/src/riptide/periodogram.py similarity index 67% rename from riptide/periodogram.py rename to src/riptide/periodogram.py index 5531e23..1cd2acf 100644 --- a/riptide/periodogram.py +++ b/src/riptide/periodogram.py @@ -7,8 +7,8 @@ class Periodogram(object): - """ - Stores the raw output of the FFA search of a time series. + """ + Stores the raw output of the FFA search of a time series. Attributes ---------- @@ -26,6 +26,7 @@ class Periodogram(object): Two dimensional array with shape (num_periods, num_widths) containing the S/N as a function of trial pulse width and period. 
""" + def __init__(self, widths, periods, foldbins, snrs, metadata=None): self.widths = widths self.periods = periods @@ -35,26 +36,32 @@ def __init__(self, widths, periods, foldbins, snrs, metadata=None): @property def freqs(self): - """ Sequence of trial frequencies in Hz, in **decreasing** order """ + """Sequence of trial frequencies in Hz, in **decreasing** order""" return 1.0 / self.periods @property def tobs(self): - """ Length in seconds of the TimeSeries that was searched """ - return self.metadata['tobs'] + """Length in seconds of the TimeSeries that was searched""" + return self.metadata["tobs"] def to_dict(self): return { - 'widths': self.widths, - 'periods': self.periods, - 'foldbins': self.foldbins, - 'snrs': self.snrs, - 'metadata': self.metadata + "widths": self.widths, + "periods": self.periods, + "foldbins": self.foldbins, + "snrs": self.snrs, + "metadata": self.metadata, } @classmethod def from_dict(cls, items): - return cls(items['widths'], items['periods'], items['foldbins'], items['snrs'], metadata=items['metadata']) + return cls( + items["widths"], + items["periods"], + items["foldbins"], + items["snrs"], + metadata=items["metadata"], + ) def plot(self, iwidth=None): """ @@ -63,7 +70,7 @@ def plot(self, iwidth=None): Parameters ---------- iwidth : int or None, optional - Display only the data for this specific pulse width trial index. + Display only the data for this specific pulse width trial index. If None, for each trial period, plot the highest S/N across all trial pulse widths. """ if iwidth is None: @@ -71,25 +78,25 @@ def plot(self, iwidth=None): else: snr = self.snrs[:, iwidth] - plt.plot(self.periods, snr, marker='o', markersize=2, alpha=0.5) + plt.plot(self.periods, snr, marker="o", markersize=2, alpha=0.5) plt.xlim(self.periods.min(), self.periods.max()) - plt.xlabel('Trial Period (s)', fontsize=16) - plt.ylabel('S/N', fontsize=16) + plt.xlabel("Trial Period (s)", fontsize=16) + plt.ylabel("S/N", fontsize=16) if iwidth is None: - plt.title('Best S/N at any trial width', fontsize=18) + plt.title("Best S/N at any trial width", fontsize=18) else: width_bins = self.widths[iwidth] - plt.title('S/N at trial width = %d' % width_bins, fontsize=18) + plt.title("S/N at trial width = %d" % width_bins, fontsize=18) plt.xticks(fontsize=14) plt.yticks(fontsize=14) - plt.grid(linestyle=':') + plt.grid(linestyle=":") plt.tight_layout() - def display(self, iwidth=None, figsize=(20,5), dpi=100): + def display(self, iwidth=None, figsize=(20, 5), dpi=100): """ - Display a plot S/N versus trial period. Creates a matplotlib figure, calls `plot()` and + Display a plot S/N versus trial period. Creates a matplotlib figure, calls `plot()` and `pyplot.show()`. 
""" plt.figure(figsize=figsize, dpi=dpi) diff --git a/riptide/pipeline/__init__.py b/src/riptide/pipeline/__init__.py similarity index 100% rename from riptide/pipeline/__init__.py rename to src/riptide/pipeline/__init__.py diff --git a/riptide/pipeline/config/example.yaml b/src/riptide/pipeline/config/example.yaml similarity index 100% rename from riptide/pipeline/config/example.yaml rename to src/riptide/pipeline/config/example.yaml diff --git a/src/riptide/pipeline/config_validation.py b/src/riptide/pipeline/config_validation.py new file mode 100644 index 0000000..663d781 --- /dev/null +++ b/src/riptide/pipeline/config_validation.py @@ -0,0 +1,259 @@ +from schema import Schema, Use, Optional, And, Or + + +class InvalidSearchRange(Exception): + pass + + +class InvalidPipelineConfig(Exception): + pass + + +def strictly_positive(x): + return x > 0 + + +VALID_FORMATS = ("presto", "sigproc") + + +SEARCH_RANGE_SCHEMA = Schema( + { + "name": str, + "ffa_search": { + "period_min": And( + Use(float), strictly_positive, error="period_min must be a number > 0" + ), + "period_max": And( + Use(float), strictly_positive, error="period_max must be a number > 0" + ), + "bins_min": And( + int, strictly_positive, error="bins_min must be an int > 0" + ), + "bins_max": And( + int, strictly_positive, error="bins_max must be an int > 0" + ), + Optional("fpmin"): And( + int, strictly_positive, error="fpmin must be an int > 0" + ), + Optional("wtsp"): And( + Use(float), lambda x: x > 1, error="wtsp must be a number > 1" + ), + Optional("ducy_max"): And( + float, + lambda x: 0 < x < 1, + error="ducy_max must be strictly between 0 and 1", + ), + }, + "find_peaks": { + Optional("smin"): And( + Use(float), strictly_positive, error="smin must be a number > 0" + ), + Optional("segwidth"): And( + Use(float), strictly_positive, error="segwidth must be a number > 0" + ), + Optional("nstd"): And( + Use(float), strictly_positive, error="nstd must be a number > 0" + ), + Optional("minseg"): And( + int, strictly_positive, error="minseg must be an int > 0" + ), + Optional("polydeg"): And( + Use(float), strictly_positive, error="polydeg must be a number > 0" + ), + Optional("clrad"): Or( + And(Use(float), strictly_positive), + None, + error="clrad must be a number > 0", + ), + }, + "candidates": { + "bins": And( + int, strictly_positive, error="candidates.bins must be an int > 0" + ), + "subints": And( + int, strictly_positive, error="candidates.subints must be an int > 0" + ), + }, + } +) + + +PIPELINE_CONFIG_SCHEMA = Schema( + { + "processes": And(int, strictly_positive, error="processes must be an int > 0"), + "data": { + "format": Schema( + lambda x: x in VALID_FORMATS, + error=f"format must be one of {VALID_FORMATS}", + ), + "fmin": Or( + And(Use(float), strictly_positive), + None, + error="fmin must be a number > 0 or null/blank", + ), + "fmax": Or( + And(Use(float), strictly_positive), + None, + error="fmax must be a number > 0 or null/blank", + ), + "nchans": Or( + And(int, strictly_positive), + None, + error="nchans must be a number > 0 or null/blank", + ), + }, + "dmselect": { + "min": Or( + Use(float), None, error="Minimum DM must be a number or null/blank" + ), + "max": Or( + Use(float), None, error="Maximum DM must be a number or null/blank" + ), + "dmsinb_max": Or( + strictly_positive, + None, + error="dmsinb_max must be a number > 0 or null/blank", + ), + }, + "dereddening": { + "rmed_width": Schema( + strictly_positive, error="rmed_width must be a number > 0" + ), + "rmed_minpts": Schema( + strictly_positive, 
error="rmed_minpts must be a number > 0" + ), + }, + "ranges": [SEARCH_RANGE_SCHEMA], + "clustering": { + "radius": Schema( + strictly_positive, error="clustering radius must be a number > 0" + ), + }, + "harmonic_flagging": { + "denom_max": And( + int, strictly_positive, error="denom_max must be an int > 0" + ), + "phase_distance_max": And( + Use(float), + strictly_positive, + error="phase_distance_max must be a number > 0", + ), + "dm_distance_max": And( + Use(float), + strictly_positive, + error="dm_distance_max must be a number > 0", + ), + "snr_distance_max": And( + Use(float), + strictly_positive, + error="snr_distance_max must be a number > 0", + ), + }, + "candidate_filters": { + "dm_min": Or( + Use(float), None, error="Candidate dm_min must be a float or null/blank" + ), + "snr_min": Or( + Use(float), + None, + error="Candidate snr_min must be a float or null/blank", + ), + "remove_harmonics": Or( + bool, None, error="remove_harmonics must be a boolean or null/blank" + ), + "max_number": Or( + And(int, strictly_positive), + None, + error="Candidate max_number must be an int > 0 or null/blank", + ), + }, + "plot_candidates": Schema(bool, error="plot_candidates must be a boolean"), + } +) + + +def validate_range(rg, tsamp_max): + """ """ + # NOTE: In general, we leave the pipeline code to raise the exceptions, + # except if it takes too long for it to detect them; for example, if the number of candidate + # bins is too large, we don't want to wait until the candidate building stage to realize this. + period_min = rg["ffa_search"]["period_min"] + period_max = rg["ffa_search"]["period_max"] + bins_min = rg["ffa_search"]["bins_min"] + cand_bins = rg["candidates"]["bins"] + + if bins_min * tsamp_max > period_min: + raise InvalidSearchRange( + f"Search range {period_min:.3e} to {period_max:.3e} seconds: requested phase " + "resolution is too high w.r.t. coarsest input time series " + f"(tsamp = {tsamp_max:.3e} seconds). Use smaller bins_min or larger period_min." + ) + + if cand_bins * tsamp_max > period_min: + raise InvalidSearchRange( + f"Search range {period_min:.3e} to {period_max:.3e} seconds: " + f"cannot fold candidates with such high resolution ({cand_bins:d} bins). " + f"The coarsest input time series ({tsamp_max:.3e} seconds) does not allow it" + ) + + +def validate_ranges_contiguity(ranges): + """ """ + for a, b in zip(ranges[:-1], ranges[1:]): + period_max_a = a["ffa_search"]["period_max"] + period_min_b = b["ffa_search"]["period_min"] + if not period_max_a == period_min_b: + raise InvalidSearchRange( + "Search ranges are not either non-contiguous, or not ordered by increasing trial " + f"period (period_max ({period_max_a:.6e}) != next period_min ({period_min_b:.6e})" + ) + + +def validate_ranges(ranges, tsamp_max): + """ + Check that the search ranges are valid. Raise an exception if not. + + Parameters + ---------- + ranges : list of dict + Search ranges read from the pipeline configuration file + tsamp_max : float + Maximum sampling interval of the TimeSeries to process + + Raises + ------ + InvalidSearchRange + """ + for rg in ranges: + validate_range(rg, tsamp_max) + validate_ranges_contiguity(ranges) + + +def validate_pipeline_config(conf): + """ + Validate pipeline configuration dictionary and raise an error if it is + incorrect. This function only checks the format of the config and + the data types. 
+
+    Parameters
+    ----------
+    conf : dict
+        Configuration dictionary loaded from the pipeline config file
+
+    Returns
+    -------
+    validated : dict
+        Validated configuration dictionary. Some data types may have been
+        changed (e.g. int to float, or float to int when both are allowed
+        for a config parameter).
+
+    Raises
+    ------
+    InvalidPipelineConfig
+    """
+    try:
+        validated = PIPELINE_CONFIG_SCHEMA.validate(conf)
+    except Exception as ex:
+        # Suppress long and confusing exception chain caused by schema library
+        raise InvalidPipelineConfig(str(ex)) from None
+    return validated
diff --git a/riptide/pipeline/dmiter.py b/src/riptide/pipeline/dmiter.py
similarity index 76%
rename from riptide/pipeline/dmiter.py
rename to src/riptide/pipeline/dmiter.py
index 10655bb..b6fc095 100644
--- a/riptide/pipeline/dmiter.py
+++ b/src/riptide/pipeline/dmiter.py
@@ -4,7 +4,7 @@
 
 from riptide import TimeSeries, Metadata
 
-log = logging.getLogger('riptide.pipeline.dmiter')
+log = logging.getLogger("riptide.pipeline.dmiter")
 
 
 # This is the standard "rounded value" of the dispersion constant in use by pulsar astronomers
@@ -46,13 +46,13 @@ def select_dms(trial_dms, dm_start, dm_end, fmin, fmax, nchans, wmin):
     fmid = (fmax + fmin) / 2.0
 
     # tsmear = ksmear * dm
-    ksmear = KDM * ((fmid-cw/2)**-2 - (fmid+cw/2)**-2)
+    ksmear = KDM * ((fmid - cw / 2) ** -2 - (fmid + cw / 2) ** -2)
 
     # Coverage radius (in DM space) of every trial DM
     # Within this radius, the total smearing time is <= wmin
     radii = np.maximum(wmin, ksmear * trial_dms) / kdisp
 
-    def dm_gap(i, j): # assumes i <= j
+    def dm_gap(i, j):  # assumes i <= j
         return (trial_dms[j] - radii[j]) - (trial_dms[i] + radii[i])
 
     def largest_in_range(i):
@@ -74,32 +74,34 @@ def largest_in_range(i):
         log.warning(
             f"The step from trial DM {trial_dms[icur]:.4f} should not exceed "
             f"{2 * radii[icur]:.4f}, "
-            f"but the next available trial DM lies farther, at {trial_dms[inext]:.4f}")
+            f"but the next available trial DM lies farther, at {trial_dms[inext]:.4f}"
+        )
         selected.append(trial_dms[inext])
         icur = inext
 
     return np.asarray(selected)
 
 
-def get_band_params(meta, fmt='presto'):
+def get_band_params(meta, fmt="presto"):
     """
     Returns (fmin, fmax, nchans) given a metadata dictionary loaded from a
     specific file format.
     """
-    if fmt == 'presto':
-        fbot = meta['fbot']
-        nchans = meta['nchan']
-        ftop = fbot + nchans * meta['cbw']
+    if fmt == "presto":
+        fbot = meta["fbot"]
+        nchans = meta["nchan"]
+        ftop = fbot + nchans * meta["cbw"]
         fmin = min(fbot, ftop)
         fmax = max(fbot, ftop)
-    elif fmt == 'sigproc':
-        raise ValueError("Cannot parse observing band parameters from data in sigproc format")
+    elif fmt == "sigproc":
+        raise ValueError(
+            "Cannot parse observing band parameters from data in sigproc format"
+        )
     else:
        raise ValueError(f"Unknown format: {fmt}")
     return fmin, fmax, nchans
 
 
-def infer_band_params(metadata_list, fmt='presto'):
+def infer_band_params(metadata_list, fmt="presto"):
     """
     Read observing band parameters of all given Metadata objects, and check
     that they are all the same (otherwise, raise RuntimeError).
@@ -109,11 +111,12 @@
         raise ValueError(
             "Cannot infer observing band parameters from empty metadata list. "
             "It appears no TimeSeries were passed as input."
- ) + ) params = [get_band_params(md, fmt=fmt) for md in metadata_list] if not all([params[0] == p for p in params]): raise RuntimeError( - "Observing band parameters are NOT identical across all dedispersed time series") + "Observing band parameters are NOT identical across all dedispersed time series" + ) return params[0] @@ -123,13 +126,16 @@ def get_galactic_coordnates(metadata_list): they are all the same (otherwise, raise RuntimeError). Returns a float tuple (gl_deg, gb_deg). """ + def galc(md): - coord = md['skycoord'].galactic + coord = md["skycoord"].galactic return coord.l.deg, coord.b.deg ref = galc(metadata_list[0]) if not all([galc(md) == ref for md in metadata_list]): - raise RuntimeError("Coordinates are NOT identical across all dedispersed time series") + raise RuntimeError( + "Coordinates are NOT identical across all dedispersed time series" + ) return ref @@ -166,19 +172,35 @@ class DMIterator(object): """ METADATA_LOADERS = { - 'sigproc': Metadata.from_sigproc, - 'presto': Metadata.from_presto_inf + "sigproc": Metadata.from_sigproc, + "presto": Metadata.from_presto_inf, } # TODO: actually implement dmsinb_max - def __init__(self, filenames, dm_start, dm_end, dmsinb_max=45.0, fmt='presto', wmin=1.0e-3, - fmin=None, fmax=None, nchans=None): + def __init__( + self, + filenames, + dm_start, + dm_end, + dmsinb_max=45.0, + fmt="presto", + wmin=1.0e-3, + fmin=None, + fmax=None, + nchans=None, + ): mdloader = self.METADATA_LOADERS[fmt] self.metadata_list = [mdloader(fname) for fname in filenames] - self.dm_start = float(dm_start) if dm_start is not None \ - else min(md['dm'] for md in self.metadata_list) - self.dm_end = float(dm_end) if dm_end is not None \ - else max(md['dm'] for md in self.metadata_list) + self.dm_start = ( + float(dm_start) + if dm_start is not None + else min(md["dm"] for md in self.metadata_list) + ) + self.dm_end = ( + float(dm_end) + if dm_end is not None + else max(md["dm"] for md in self.metadata_list) + ) self.dmsinb_max = float(dmsinb_max) if dmsinb_max is not None else None self.fmt = fmt self.wmin = wmin @@ -190,21 +212,27 @@ def __init__(self, filenames, dm_start, dm_end, dmsinb_max=45.0, fmt='presto', w log.info( f"Applying DM|sin b| cap of {self.dmsinb_max:.4f}: " f"At b = {gb_deg:.2f} deg this means a max DM of {galactic_dm_cap:.4f}" - ) + ) self.dm_end = min(self.dm_end, galactic_dm_cap) - - log.info(f"Selecting DM trials in the range {self.dm_start:.4f} to {self.dm_end:.4f}") + + log.info( + f"Selecting DM trials in the range {self.dm_start:.4f} to {self.dm_end:.4f}" + ) # Try to infer band parameters from the data try: - (self.fmin, self.fmax, self.nchans) = infer_band_params(self.metadata_list, fmt=fmt) + (self.fmin, self.fmax, self.nchans) = infer_band_params( + self.metadata_list, fmt=fmt + ) log.info( "Inferred observing band parameters from input files: " f"fmin = {self.fmin:.3f}, fmax = {self.fmax:.3f}, nchans = {self.nchans:d}. " "Any manually specified values of fmin/fmax/nchans will be ignored." 
- ) + ) except (ValueError, RuntimeError) as err: - log.info(f"Could not infer observing band parameters from input files: {err!s}") + log.info( + f"Could not infer observing band parameters from input files: {err!s}" + ) log.info("Using manually specified band parameters instead") if any([param is None for param in (fmin, fmax, nchans)]): raise ValueError("You MUST specify: fmin, fmax, nchans") @@ -212,21 +240,29 @@ def __init__(self, filenames, dm_start, dm_end, dmsinb_max=45.0, fmt='presto', w (self.fmin, self.fmax, self.nchans) = (fmin, fmax, nchans) log.info( f"Using: fmin = {self.fmin:.3f}, " - f"fmax = {self.fmax:.3f}, nchans = {self.nchans:d}") + f"fmax = {self.fmax:.3f}, nchans = {self.nchans:d}" + ) - self.metadata_dict = {meta['dm']: meta for meta in self.metadata_list} + self.metadata_dict = {meta["dm"]: meta for meta in self.metadata_list} log.info( f"Selecting minimal trial DM subset to cover the DM range " - f"{self.dm_start:.4f} to {self.dm_end:.4f}") + f"{self.dm_start:.4f} to {self.dm_end:.4f}" + ) self.selected_dms = select_dms( list(self.metadata_dict.keys()), - self.dm_start, self.dm_end, self.fmin, self.fmax, self.nchans, self.wmin - ) + self.dm_start, + self.dm_end, + self.fmin, + self.fmax, + self.nchans, + self.wmin, + ) log.info( f"Selected {len(self.selected_dms)} DM trials for processing: " - f"{list(self.selected_dms)}") + f"{list(self.selected_dms)}" + ) def iterate_filenames(self, chunksize=1): """ @@ -234,19 +270,19 @@ def iterate_filenames(self, chunksize=1): """ chunk = [] for dm in self.selected_dms: - fname = self.metadata_dict[dm]['fname'] + fname = self.metadata_dict[dm]["fname"] chunk.append(fname) if len(chunk) == chunksize: yield chunk chunk = [] - if chunk: # yield any non-empty, incomplete last chunk + if chunk: # yield any non-empty, incomplete last chunk yield chunk def get_filename(self, dm): - return self.metadata_dict[dm]['fname'] + return self.metadata_dict[dm]["fname"] def tobs_median(self): - return np.median([md['tobs'] for md in self.metadata_list]) + return np.median([md["tobs"] for md in self.metadata_list]) def tsamp_max(self): - return max([md['tsamp'] for md in self.metadata_list]) + return max([md["tsamp"] for md in self.metadata_list]) diff --git a/riptide/pipeline/harmonic_testing.py b/src/riptide/pipeline/harmonic_testing.py similarity index 85% rename from riptide/pipeline/harmonic_testing.py rename to src/riptide/pipeline/harmonic_testing.py index 7878178..663d8fb 100644 --- a/riptide/pipeline/harmonic_testing.py +++ b/src/riptide/pipeline/harmonic_testing.py @@ -3,12 +3,12 @@ import logging -log = logging.getLogger('riptide.pipeline.harmonic_filter') +log = logging.getLogger("riptide.pipeline.harmonic_filter") def hdiag(F, H, tobs, fmin, fmax, denom_max=100): """ - Calculate a number of diagnostic values to evaluate whether two sets of + Calculate a number of diagnostic values to evaluate whether two sets of candidate parameters are harmonically related. Parameters @@ -29,7 +29,7 @@ def hdiag(F, H, tobs, fmin, fmax, denom_max=100): fmax: float Top effective observing frequency in Hz denom_max: int, optional - Maximum allowed denominator of the harmonic fraction by which the + Maximum allowed denominator of the harmonic fraction by which the frequencies of both candidates are related. It must be limited, otherwise there is always a rational fraction arbitrarily close to the ratio of their frequencies. 
(default: 100)
@@ -76,23 +76,33 @@ def width(X):
     snr_distance = abs(H.snr - harmonic_snr_expected)
 
     return {
-        'fraction': fraction,
-        'phase_absdiff_turns': phase_absdiff_turns,
-        'phase_distance' : phase_distance,
-        'dm_absdiff': dm_absdiff,
-        'dm_delay_absdiff': dm_delay_absdiff,
-        'dm_distance': dm_distance,
-        'harmonic_snr_expected': harmonic_snr_expected,
-        'snr_distance': snr_distance,
+        "fraction": fraction,
+        "phase_absdiff_turns": phase_absdiff_turns,
+        "phase_distance": phase_distance,
+        "dm_absdiff": dm_absdiff,
+        "dm_delay_absdiff": dm_delay_absdiff,
+        "dm_distance": dm_distance,
+        "harmonic_snr_expected": harmonic_snr_expected,
+        "snr_distance": snr_distance,
     }
 
 
-def htest(F, H, tobs, fmin, fmax, denom_max=100, phase_distance_max=1.0, dm_distance_max=3.0, snr_distance_max=3.0):
+def htest(
+    F,
+    H,
+    tobs,
+    fmin,
+    fmax,
+    denom_max=100,
+    phase_distance_max=1.0,
+    dm_distance_max=3.0,
+    snr_distance_max=3.0,
+):
     """
     Test whether two sets of candidate parameters are harmonically related.
-    The code first finds the closest rational fraction p/q to the ratio
+    The code first finds the closest rational fraction p/q to the ratio
     H.freq / F.freq and then tests whether H is the plausible p/q-th harmonic
-    of F. The method is *purposely* designed to under-flag rather than
+    of F. The method is *purposely* designed to under-flag rather than
     over-flag, noting also that pipeline users can decide not to remove
     harmonics from the final candidate list.
 
@@ -114,21 +124,21 @@ def htest(F, H, tobs, fmin, fmax, denom_max=100, phase_distance_max=1.0, dm_dist
     fmax: float
         Top effective observing frequency in Hz
     denom_max: int, optional
-        Maximum allowed denominator of the harmonic fraction by which the
+        Maximum allowed denominator of the harmonic fraction by which the
         frequencies of both candidates are related. It must be limited,
         otherwise there is always a rational fraction arbitrarily close to the
         ratio of their frequencies. (default: 100)
     phase_distance_max: float
-        Upper bound on the phase delay (in number of pulse widths) accrued
+        Upper bound on the phase delay (in number of pulse widths) accrued
         over 'tobs' seconds between the signal H and the hypothesised harmonic
-        p/q x F. A value of 1.0 means that the harmonic relationship is
+        p/q x F. A value of 1.0 means that the harmonic relationship is
         credible only if both trains of pulses remain within one pulse width of each
-        other. This the proper way to measure if the frequencies H.freq and
+        other. This is the proper way to measure if the frequencies H.freq and
         p/q x F.freq are significantly close. (default: 1.0)
     dm_distance_max: float
-        Upper bound on the difference between dispersion delays (expressed in
-        pulse widths) across the observing band associated to the DMs of
-        F and H. (default: 3.0)
+        Upper bound on the difference between dispersion delays (expressed in
+        pulse widths) across the observing band associated with the DMs of
+        F and H. (default: 3.0)
     snr_distance_max: float
         Upper bound on the absolute difference between the true S/N of H and
         the S/N that it should have if it was the p/q harmonic of F.
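
The bounded-denominator fraction matching described above maps directly onto
the standard library's fractions.Fraction.limit_denominator. The sketch below
is illustrative only (it is not riptide's implementation, and the function
names are hypothetical): it approximates H.freq / F.freq by the closest
fraction p/q with q <= denom_max, then measures the residual phase drift
accrued over the observation, which hdiag then expresses in pulse widths
before thresholding:

    from fractions import Fraction

    def closest_fraction(f_fund, f_harm, denom_max=100):
        # Best rational approximation p/q to f_harm / f_fund with q <= denom_max
        return Fraction(f_harm / f_fund).limit_denominator(denom_max)

    def phase_drift_turns(f_fund, f_harm, tobs, denom_max=100):
        # Turns of phase accrued over tobs between H and the hypothesised
        # harmonic (p/q) x F; small values support a harmonic relationship
        frac = closest_fraction(f_fund, f_harm, denom_max)
        return abs(f_harm - float(frac) * f_fund) * tobs
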
@@ -147,9 +157,10 @@ def htest(F, H, tobs, fmin, fmax, denom_max=100, phase_distance_max=1.0, dm_dist The rational fraction p/q closest to H.freq / F.freq """ dvals = hdiag(F, H, tobs, fmin, fmax, denom_max=denom_max) - related = \ - dvals['phase_distance'] <= phase_distance_max and \ - dvals['dm_distance'] <= dm_distance_max and \ - dvals['snr_distance'] <= snr_distance_max - fraction = dvals['fraction'] + related = ( + dvals["phase_distance"] <= phase_distance_max + and dvals["dm_distance"] <= dm_distance_max + and dvals["snr_distance"] <= snr_distance_max + ) + fraction = dvals["fraction"] return related, fraction diff --git a/riptide/pipeline/peak_cluster.py b/src/riptide/pipeline/peak_cluster.py similarity index 78% rename from riptide/pipeline/peak_cluster.py rename to src/riptide/pipeline/peak_cluster.py index 29489c9..23a54bb 100644 --- a/riptide/pipeline/peak_cluster.py +++ b/src/riptide/pipeline/peak_cluster.py @@ -2,7 +2,7 @@ class PeakCluster(list): - """ + """ Basic list subclass to store a cluster of Peak objects Parameters @@ -22,6 +22,7 @@ class PeakCluster(list): cluster's frequency and its fundamental's frequency (default: None) """ + def __init__(self, peaks, rank=None, parent_fundamental=None, hfrac=None): super(PeakCluster, self).__init__(peaks) self.rank = rank @@ -42,25 +43,23 @@ def summary_dataframe(self): objects, where the columns are the keys of the dictionary returned by the Peak.summary_dict() method """ - return pandas.DataFrame.from_dict([ - peak.summary_dict() for peak in self - ]) - + return pandas.DataFrame.from_dict([peak.summary_dict() for peak in self]) + def summary_dict(self): - """ - """ + """ """ return { **self.centre.summary_dict(), - 'npeaks': len(self), - + "npeaks": len(self), # NOTE: we set some default values when there is no fundamental, instead of None # This is to work around a limitation of pandas.DataFrame where columns with missing # values MUST be of type float, and we want type 'int' for these - 'rank': self.rank, - 'hfrac_num': self.hfrac.numerator if self.is_harmonic else 0, - 'hfrac_denom': self.hfrac.denominator if self.is_harmonic else 0, - 'fundamental_rank': self.parent_fundamental.rank if self.is_harmonic else self.rank - } + "rank": self.rank, + "hfrac_num": self.hfrac.numerator if self.is_harmonic else 0, + "hfrac_denom": self.hfrac.denominator if self.is_harmonic else 0, + "fundamental_rank": ( + self.parent_fundamental.rank if self.is_harmonic else self.rank + ), + } def __str__(self): name = type(self).__name__ @@ -78,8 +77,19 @@ def clusters_to_dataframe(clusters): """ clusters = sorted(clusters, key=lambda c: c.centre.snr, reverse=True) df = pandas.DataFrame.from_dict([cl.summary_dict() for cl in clusters]) - + # Re-order columns - columns = ['rank', 'period', 'dm', 'snr', 'ducy', 'freq', 'npeaks', 'hfrac_num', 'hfrac_denom', 'fundamental_rank'] + columns = [ + "rank", + "period", + "dm", + "snr", + "ducy", + "freq", + "npeaks", + "hfrac_num", + "hfrac_denom", + "fundamental_rank", + ] df = df[columns] - return df \ No newline at end of file + return df diff --git a/riptide/pipeline/pipeline.py b/src/riptide/pipeline/pipeline.py similarity index 72% rename from riptide/pipeline/pipeline.py rename to src/riptide/pipeline/pipeline.py index c0a1ee1..205da05 100644 --- a/riptide/pipeline/pipeline.py +++ b/src/riptide/pipeline/pipeline.py @@ -26,7 +26,7 @@ from riptide.timing import timing -log = logging.getLogger('riptide.pipeline') +log = logging.getLogger("riptide.pipeline") class CandidateWriter(object): @@ -34,6 +34,7 
@@ class CandidateWriter(object): func-like object to be used in conjunction with multiprocessing.Pool to write candidates with multiple processes """ + def __init__(self, outdir, plot=False): self.outdir = os.path.realpath(outdir) self.plot = plot @@ -44,11 +45,11 @@ def __call__(self, arg): (int, Candidate) """ rank, cand = arg - fname = os.path.join(self.outdir, f'candidate_{rank:04d}.json') + fname = os.path.join(self.outdir, f"candidate_{rank:04d}.json") log.debug(f"Saving to {fname}: {cand}") save_json(fname, cand) if self.plot: - fname = os.path.join(self.outdir, f'candidate_{rank:04d}.png') + fname = os.path.join(self.outdir, f"candidate_{rank:04d}.png") log.debug(f"Saving plot to {fname}") cand.savefig(fname) @@ -62,6 +63,7 @@ class Pipeline(object): conf: dict Configuration dictionary loaded from YAML file """ + def __init__(self, config): # This only validates the format, not the actual parameter values. # More checks are performed later when the parameters of the input @@ -75,10 +77,10 @@ def __init__(self, config): self.candidates = [] def wmin(self): - """ Minimum pulse width being searched for """ - search_ranges = self.config['ranges'] + """Minimum pulse width being searched for""" + search_ranges = self.config["ranges"] min_widths = [ - kw['ffa_search']['period_min'] / kw['ffa_search']['bins_min'] + kw["ffa_search"]["period_min"] / kw["ffa_search"]["bins_min"] for kw in search_ranges ] return min(min_widths) @@ -106,12 +108,11 @@ def get_search_range(self, period): # The code below can return wrong results if the ranges do not connect # perfectly with each other ranges = sorted( - self.config['ranges'], - key=lambda r: r['ffa_search']['period_max'] - ) - - pmin_global = min(rng['ffa_search']['period_min'] for rng in ranges) - pmax_global = max(rng['ffa_search']['period_max'] for rng in ranges) + self.config["ranges"], key=lambda r: r["ffa_search"]["period_max"] + ) + + pmin_global = min(rng["ffa_search"]["period_min"] for rng in ranges) + pmax_global = max(rng["ffa_search"]["period_max"] for rng in ranges) if period < pmin_global: msg = ( @@ -127,8 +128,8 @@ def get_search_range(self, period): return dict(ranges[-1]) for rng in ranges: - pmin = rng['ffa_search']['period_min'] - pmax = rng['ffa_search']['period_max'] + pmin = rng["ffa_search"]["period_min"] + pmax = rng["ffa_search"]["period_max"] if pmin <= period < pmax: return dict(rng) @@ -149,30 +150,31 @@ def prepare(self, files): # - yielding them in chunks of size = number of parallel processes self.dmiter = DMIterator( files, - conf['dmselect']['min'], - conf['dmselect']['max'], - dmsinb_max=conf['dmselect']['dmsinb_max'], - fmt=conf['data']['format'], + conf["dmselect"]["min"], + conf["dmselect"]["max"], + dmsinb_max=conf["dmselect"]["dmsinb_max"], + fmt=conf["data"]["format"], wmin=self.wmin(), - fmin=conf['data']['fmin'], - fmax=conf['data']['fmax'], - nchans=conf['data']['nchans'], + fmin=conf["data"]["fmin"], + fmax=conf["data"]["fmax"], + nchans=conf["data"]["nchans"], ) tsamp_max = self.dmiter.tsamp_max() log.info( - f"Max sampling time = {tsamp_max:.6e} s, checking pipeline config parameter values") - validate_ranges(conf['ranges'], tsamp_max) + f"Max sampling time = {tsamp_max:.6e} s, checking pipeline config parameter values" + ) + validate_ranges(conf["ranges"], tsamp_max) # NOTE: call dmiter.prepare() first. 
Before that, dmiter.tsloader is None self.worker_pool = WorkerPool( - conf['dereddening'], - conf['ranges'], - processes=conf['processes'], - fmt=conf['data']['format'] + conf["dereddening"], + conf["ranges"], + processes=conf["processes"], + fmt=conf["data"]["format"], ) log.info("Pipeline ready") - + @timing def search(self): """ @@ -180,10 +182,8 @@ def search(self): """ log.info("Running search") peaks = [] - for fnames in self.dmiter.iterate_filenames(chunksize=self.config['processes']): - peaks.extend( - self.worker_pool.process_fname_list(fnames) - ) + for fnames in self.dmiter.iterate_filenames(chunksize=self.config["processes"]): + peaks.extend(self.worker_pool.process_fname_list(fnames)) self.peaks = sorted(peaks, key=lambda p: p.period) log.info("Total peaks found: {}".format(len(peaks))) log.info("Search complete") @@ -197,7 +197,7 @@ def cluster_peaks(self): log.info("Clustering peaks") conf = self.config tmed = self.dmiter.tobs_median() - clrad = conf['clustering']['radius'] / tmed + clrad = conf["clustering"]["radius"] / tmed log.debug(f"Median Tobs = {tmed:.2f} s") log.debug(f"Frequency clustering radius = {clrad:.3e} Hz") @@ -207,8 +207,7 @@ def cluster_peaks(self): cluster_ids = cluster1d(freqs, clrad, already_sorted=True) self.clusters = [ - PeakCluster((self.peaks[ii] for ii in ids)) - for ids in cluster_ids + PeakCluster((self.peaks[ii] for ii in ids)) for ids in cluster_ids ] log.info(f"Total clusters found: {len(self.clusters)}") @@ -224,9 +223,11 @@ def flag_harmonics(self): tobs = self.dmiter.tobs_median() fmin = self.dmiter.fmin fmax = self.dmiter.fmax - kwargs = self.config['harmonic_flagging'] - - clusters_decreasing_snr = sorted(self.clusters, key=lambda c: c.centre.snr, reverse=True) + kwargs = self.config["harmonic_flagging"] + + clusters_decreasing_snr = sorted( + self.clusters, key=lambda c: c.centre.snr, reverse=True + ) # Assign ranks first for rank, cl in enumerate(clusters_decreasing_snr): @@ -242,7 +243,7 @@ def flag_harmonics(self): if related: H.parent_fundamental = F H.hfrac = fraction - + harmonics = list(filter(lambda c: c.is_harmonic, self.clusters)) log.info(f"Harmonics flagged: {len(harmonics)}") log.info(f"Fundamental clusters: {len(self.clusters) - len(harmonics)}") @@ -251,29 +252,37 @@ def flag_harmonics(self): @timing def apply_candidate_filters(self): log.info("Applying candidate filters") - params = self.config['candidate_filters'] + params = self.config["candidate_filters"] clusters_filtered = self.clusters # DM cut - dm_min = params['dm_min'] + dm_min = params["dm_min"] if dm_min is not None: log.warning(f"Applying DM threshold of {dm_min}") - clusters_filtered = list(filter(lambda c: c.centre.dm >= dm_min, clusters_filtered)) + clusters_filtered = list( + filter(lambda c: c.centre.dm >= dm_min, clusters_filtered) + ) # S/N cut - snr_min = params['snr_min'] + snr_min = params["snr_min"] if snr_min is not None: log.warning(f"Applying S/N threshold of {snr_min}") - clusters_filtered = list(filter(lambda c: c.centre.snr >= snr_min, clusters_filtered)) + clusters_filtered = list( + filter(lambda c: c.centre.snr >= snr_min, clusters_filtered) + ) # Harmonic removal - if params['remove_harmonics']: - log.warning("Harmonic removal is enabled, clusters flagged as harmonics will NOT be output as candidates") - clusters_filtered = list(filter(lambda c: not c.is_harmonic, clusters_filtered)) + if params["remove_harmonics"]: + log.warning( + "Harmonic removal is enabled, clusters flagged as harmonics will NOT be output as candidates" + ) + 
clusters_filtered = list( + filter(lambda c: not c.is_harmonic, clusters_filtered) + ) # Cap on number of candidates to build - nmax = params['max_number'] + nmax = params["max_number"] if nmax: if len(clusters_filtered) > nmax: nleft = len(clusters_filtered) @@ -282,17 +291,22 @@ def apply_candidate_filters(self): f"Number of clusters remaining ({nleft}) exceeds the maximum specified number of candidates ({nmax}). " f"The faintest {nexcess} will not be saved as candidates" ) - clusters_filtered = sorted(clusters_filtered, key=lambda c: c.centre.snr, reverse=True) + clusters_filtered = sorted( + clusters_filtered, key=lambda c: c.centre.snr, reverse=True + ) clusters_filtered = clusters_filtered[:nmax] self.clusters_filtered = clusters_filtered - log.info(f"Candidate filters applied. Clusters remaining: {len(self.clusters_filtered)}") - + log.info( + f"Candidate filters applied. Clusters remaining: {len(self.clusters_filtered)}" + ) @timing def build_candidates(self): - log.info("Building candidates") - clusters_decreasing_snr = sorted(self.clusters_filtered, key=lambda c: c.centre.snr, reverse=True) + log.info("Building candidates") + clusters_decreasing_snr = sorted( + self.clusters_filtered, key=lambda c: c.centre.snr, reverse=True + ) if not clusters_decreasing_snr: log.info("No clusters: no candidates to build") @@ -304,21 +318,28 @@ def build_candidates(self): for cl in clusters_decreasing_snr: grouped_clusters[cl.centre.dm].append(cl) - log.debug(f"{len(clusters_decreasing_snr)} candidates to build from {len(grouped_clusters)} TimeSeries") + log.debug( + f"{len(clusters_decreasing_snr)} candidates to build from {len(grouped_clusters)} TimeSeries" + ) for dm, clusters in grouped_clusters.items(): fname = self.dmiter.get_filename(dm) ts = self.worker_pool.loader(fname) ts = ts.deredden( - width=self.config['dereddening']['rmed_width'], - minpts=self.config['dereddening']['rmed_minpts']) + width=self.config["dereddening"]["rmed_width"], + minpts=self.config["dereddening"]["rmed_minpts"], + ) ts = ts.normalise() for cl in clusters: try: rng = self.get_search_range(cl.centre.period) cand = Candidate.from_pipeline_output( - ts, cl, rng['candidates']['bins'], subints=rng['candidates']['subints']) + ts, + cl, + rng["candidates"]["bins"], + subints=rng["candidates"]["subints"], + ) self.candidates.append(cand) # NOTE: this should never happen ideally @@ -328,14 +349,15 @@ def build_candidates(self): log.error(err) log.error(traceback.format_exc()) - self.candidates = sorted(self.candidates, key=lambda c: c.params['snr'], reverse=True) + self.candidates = sorted( + self.candidates, key=lambda c: c.params["snr"], reverse=True + ) log.info(f"Total candidates: {len(self.candidates)}") log.info("Done building candidates") @timing def save_products(self, outdir=os.getcwd()): - """ - """ + """ """ log.info("Building products") if not self.peaks: @@ -343,32 +365,34 @@ def save_products(self, outdir=os.getcwd()): return # CSV of Peak data - df_peaks = pandas.DataFrame.from_dict([ - peak.summary_dict() for peak in self.peaks - ]) - df_peaks_fname = os.path.join(outdir, 'peaks.csv') - df_peaks.to_csv(df_peaks_fname, sep=',', index=False, float_format='%.9f') + df_peaks = pandas.DataFrame.from_dict( + [peak.summary_dict() for peak in self.peaks] + ) + df_peaks_fname = os.path.join(outdir, "peaks.csv") + df_peaks.to_csv(df_peaks_fname, sep=",", index=False, float_format="%.9f") log.info("Saved Peak data to {!r}".format(df_peaks_fname)) ### CSV of cluster data if self.clusters: df_clusters = 
clusters_to_dataframe(self.clusters) - df_clusters_fname = os.path.join(outdir, 'clusters.csv') - df_clusters.to_csv(df_clusters_fname, sep=',', index=False, float_format='%.9f') + df_clusters_fname = os.path.join(outdir, "clusters.csv") + df_clusters.to_csv( + df_clusters_fname, sep=",", index=False, float_format="%.9f" + ) log.info("Saved Cluster data to {!r}".format(df_peaks_fname)) ### CSV of basic candidate parameters if self.candidates: - df_cands = pandas.DataFrame.from_dict([ - cand.params for cand in self.candidates - ]) - df_cands_fname = os.path.join(outdir, 'candidates.csv') - df_cands.to_csv(df_cands_fname, sep=',', index=False, float_format='%.9f') + df_cands = pandas.DataFrame.from_dict( + [cand.params for cand in self.candidates] + ) + df_cands_fname = os.path.join(outdir, "candidates.csv") + df_cands.to_csv(df_cands_fname, sep=",", index=False, float_format="%.9f") ### Candidates and candidate plots log.info("Writing candidate files") - pool = multiprocessing.Pool(processes=self.config['processes']) - writer = CandidateWriter(outdir, plot=self.config['plot_candidates']) + pool = multiprocessing.Pool(processes=self.config["processes"]) + writer = CandidateWriter(outdir, plot=self.config["plot_candidates"]) arglist = [(rank, cand) for rank, cand in enumerate(self.candidates)] pool.map(writer, arglist) @@ -377,7 +401,7 @@ def save_products(self, outdir=os.getcwd()): # the pool pool.close() pool.join() - log.info("Data products written") + log.info("Data products written") @timing def process(self, files, outdir): @@ -396,7 +420,7 @@ def process(self, files, outdir): @classmethod def from_yaml_config(cls, fname): log.debug("Creating pipeline from config file: {}".format(fname)) - with open(fname, 'r') as fobj: + with open(fname, "r") as fobj: conf = yaml.safe_load(fobj) log.debug("Pipeline configuration: {}".format(json.dumps(conf, indent=4))) return cls(conf) @@ -405,12 +429,14 @@ def from_yaml_config(cls, fname): ############################################################################### -help_formatter = lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=16) +help_formatter = lambda prog: argparse.ArgumentDefaultsHelpFormatter( + prog, max_help_position=16 +) def get_parser(): def outdir(path): - """ Function that checks the outdir argument """ + """Function that checks the outdir argument""" if not os.path.isdir(path): msg = "Specified output directory {!r} does not exist".format(path) raise argparse.ArgumentTypeError(msg) @@ -418,7 +444,7 @@ def outdir(path): parser = argparse.ArgumentParser( formatter_class=help_formatter, - description=f"Search multiple DM trials with the riptide end-to-end FFA pipeline." + description=f"Search multiple DM trials with the riptide end-to-end FFA pipeline.", ) parser.add_argument( "-c", @@ -435,34 +461,34 @@ def outdir(path): help="Output directory for the data products", ) parser.add_argument( - "-f", + "-f", "--logfile", type=str, default=None, - help="Save logs to given file. If not specified, no logfile is saved" + help="Save logs to given file. 
If not specified, no logfile is saved",
     )
     parser.add_argument(
         "--log-level",
         type=str,
-        default='DEBUG',
+        default="DEBUG",
         help="Logging level for the riptide logger",
-        choices=['DEBUG', 'INFO', 'WARNING']
+        choices=["DEBUG", "INFO", "WARNING"],
     )
     parser.add_argument(
         "--log-timings",
-        action='store_true',
-        help="If this flag is specified, log the execution times of all major functions"
+        action="store_true",
+        help="If this flag is specified, log the execution times of all major functions",
     )
+    parser.add_argument("--version", action="version", version=__version__)
     parser.add_argument(
-        '--version', action='version', version=__version__
-    )
-    parser.add_argument("files", type=str, nargs="+", help="Input file(s) of the right format")
+        "files", type=str, nargs="+", help="Input file(s) of the right format"
+    )
     return parser
 
 
 def run_program(args):
     ### Select non-interactive backend
-    # matplotlib.use('Agg') would not work here, due to importing order
+    # matplotlib.use('Agg') would not work here, due to importing order
     # the console_scripts entry point design means that 'riptide' is always imported first,
     # importing everything else in riptide's __init__.py, which ends up setting the backend
     # before the first line of this script is reached
@@ -474,28 +500,29 @@ def run_program(args):
     # is called from the unit test suite and requires a backend switch as well
     # (otherwise we get a crash on Travis OSX virtual machine)
     import matplotlib.pyplot as plt
-    plt.switch_backend('Agg')
+
+    plt.switch_backend("Agg")
 
     handlers = [logging.StreamHandler()]
     if args.logfile:
-        handlers.append(logging.FileHandler(args.logfile, mode='w'))
+        handlers.append(logging.FileHandler(args.logfile, mode="w"))
 
     logging.basicConfig(
         level=args.log_level,
-        format='%(asctime)s %(filename)18s:%(lineno)-4s %(levelname)-8s %(message)s',
-        handlers=handlers
+        format="%(asctime)s %(filename)18s:%(lineno)-4s %(levelname)-8s %(message)s",
+        handlers=handlers,
     )
-    logging.getLogger('matplotlib').setLevel('WARNING') # otherwise it can get annoying
+    logging.getLogger("matplotlib").setLevel("WARNING")  # otherwise it can get annoying
 
     if args.log_timings:
-        logging.getLogger('riptide.timing').setLevel('DEBUG')
+        logging.getLogger("riptide.timing").setLevel("DEBUG")
     else:
-        logging.getLogger('riptide.timing').setLevel('WARNING')
+        logging.getLogger("riptide.timing").setLevel("WARNING")
 
     pipeline = Pipeline.from_yaml_config(args.config)
     pipeline.process(args.files, args.outdir)
 
-    # If you have seen the movie "The Martian" and always wanted to look like
+    # If you have seen the movie "The Martian" and always wanted to look like
     # an actual scientist to your friends and family. Thank me later.
log.info("CALCULATIONS CORRECT") @@ -503,12 +530,12 @@ def run_program(args): # NOTE: main() is the entry point of the console script def main(): # NOTE (IMPORTANT): Force all numpy libraries to use a single thread/CPU - # Each DM trial is assigned to a different process, and for optimal + # Each DM trial is assigned to a different process, and for optimal # performance, each process should be limited to 1 CPU with threadpoolctl.threadpool_limits(limits=1): parser = get_parser() run_program(parser.parse_args()) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/riptide/pipeline/worker_pool.py b/src/riptide/pipeline/worker_pool.py similarity index 73% rename from riptide/pipeline/worker_pool.py rename to src/riptide/pipeline/worker_pool.py index 575d535..ea04a70 100644 --- a/riptide/pipeline/worker_pool.py +++ b/src/riptide/pipeline/worker_pool.py @@ -4,11 +4,11 @@ from riptide import TimeSeries, ffa_search, find_peaks -log = logging.getLogger('riptide.worker_pool') +log = logging.getLogger("riptide.worker_pool") class WorkerPool(object): - """ + """ deredden_params : dict range_confs : list of dicts List of dicts from the 'ranges' section of the YAML config file @@ -22,11 +22,11 @@ class WorkerPool(object): """ TIMESERIES_LOADERS = { - 'sigproc': TimeSeries.from_sigproc, - 'presto': TimeSeries.from_presto_inf + "sigproc": TimeSeries.from_sigproc, + "presto": TimeSeries.from_presto_inf, } - def __init__(self, deredden_params, range_confs, processes=1, fmt='presto'): + def __init__(self, deredden_params, range_confs, processes=1, fmt="presto"): self.deredden_params = deredden_params self.range_confs = range_confs self.loader = self.TIMESERIES_LOADERS[fmt] @@ -47,25 +47,22 @@ def process_fname_list(self, fnames): def process_fname(self, fname): allpeaks = [] ts = self.loader(fname) - dm = ts.metadata['dm'] + dm = ts.metadata["dm"] log.debug("Searching DM = {:.3f}".format(dm)) # Make pre-processing common to all ranges to save time ts = ts.deredden( - self.deredden_params['rmed_width'], - minpts=self.deredden_params['rmed_minpts'] - ) + self.deredden_params["rmed_width"], + minpts=self.deredden_params["rmed_minpts"], + ) ts = ts.normalise() for conf in self.range_confs: - kw_search = dict(conf['ffa_search']) - kw_search.update({ - 'deredden': False, - 'already_normalised': True - }) + kw_search = dict(conf["ffa_search"]) + kw_search.update({"deredden": False, "already_normalised": True}) tsdr, pgram = ffa_search(ts, **kw_search) - peaks, polycos = find_peaks(pgram, **conf['find_peaks']) + peaks, polycos = find_peaks(pgram, **conf["find_peaks"]) allpeaks.extend(peaks) - del tsdr, pgram, peaks, polycos # Free RAM ASAP + del tsdr, pgram, peaks, polycos # Free RAM ASAP log.debug(f"Done searching DM = {dm:.3f}, peaks found: {len(allpeaks)}") - return allpeaks \ No newline at end of file + return allpeaks diff --git a/riptide/reading/__init__.py b/src/riptide/reading/__init__.py similarity index 100% rename from riptide/reading/__init__.py rename to src/riptide/reading/__init__.py diff --git a/riptide/reading/presto.py b/src/riptide/reading/presto.py similarity index 58% rename from riptide/reading/presto.py rename to src/riptide/reading/presto.py index cea329a..522d36e 100644 --- a/riptide/reading/presto.py +++ b/src/riptide/reading/presto.py @@ -31,7 +31,7 @@ # [ADDITIONAL NOTES] -SEP = '=' +SEP = "=" SEP_COLUMN = 40 FAKE_TELESCOPE = "None (Artificial Data Set)" @@ -40,22 +40,22 @@ def parse_inf_value(line, vtype): # A "standard" .inf 
line has an '=' character at column 40 if not (len(line) > SEP_COLUMN and line[SEP_COLUMN] == SEP): raise ValueError(f"Expected '=' character at column {SEP_COLUMN}") - val = line[SEP_COLUMN+1:].strip() + val = line[SEP_COLUMN + 1 :].strip() return vtype(val) def parse_bool(s): - """ Convert '1' or '0' to boolean """ + """Convert '1' or '0' to boolean""" return int(s) != 0 def parse_int_pair(s): - a, b = s.split(',') + a, b = s.split(",") return int(a), int(b) def inf2dict(text): - lines = text.strip('\n').splitlines() + lines = text.strip("\n").splitlines() def p(n, vtype): return parse_inf_value(lines[n], vtype) @@ -65,55 +65,57 @@ def p(n, vtype): telescope = p(1, str) if telescope == FAKE_TELESCOPE: - raise ValueError("Reading data generated with PRESTO's makedata is not supported") + raise ValueError( + "Reading data generated with PRESTO's makedata is not supported" + ) # Parse block common to all EM bands items = { - 'basename': basename, - 'telescope': telescope, - 'instrument': p(2, str), - 'source_name': p(3, str), - 'raj': p(4, str), - 'decj': p(5, str), - 'observer': p(6, str), - 'mjd': p(7, float), - 'barycentered': p(8, parse_bool), - 'nsamp': p(9, int), - 'tsamp': p(10, float), - 'breaks': p(11, parse_bool), - 'onoff_pairs': [] + "basename": basename, + "telescope": telescope, + "instrument": p(2, str), + "source_name": p(3, str), + "raj": p(4, str), + "decj": p(5, str), + "observer": p(6, str), + "mjd": p(7, float), + "barycentered": p(8, parse_bool), + "nsamp": p(9, int), + "tsamp": p(10, float), + "breaks": p(11, parse_bool), + "onoff_pairs": [], } lines = lines[12:] # Parse on/off pairs, if any - if items['breaks']: + if items["breaks"]: for line in lines: try: pair = parse_inf_value(line, parse_int_pair) - items['onoff_pairs'].append(pair) + items["onoff_pairs"].append(pair) except: break - num_onoff = len(items['onoff_pairs']) + num_onoff = len(items["onoff_pairs"]) lines = lines[num_onoff:] em_band = p(0, str) - items['em_band'] = em_band - - if em_band == 'Radio': - items['fov_arcsec'] = p(1, float) - items['dm'] = p(2, float) - items['fbot'] = p(3, float) - items['bandwidth'] = p(4, float) - items['nchan'] = p(5, int) - items['cbw'] = p(6, float) - items['analyst'] = p(7, str) - - elif em_band in ('X-ray', 'Gamma'): - items['fov_arcsec'] = p(1, float) - items['central_energy_kev'] = p(2, float) - items['energy_bandpass_kev'] = p(3, float) - items['analyst'] = p(4, str) + items["em_band"] = em_band + + if em_band == "Radio": + items["fov_arcsec"] = p(1, float) + items["dm"] = p(2, float) + items["fbot"] = p(3, float) + items["bandwidth"] = p(4, float) + items["nchan"] = p(5, int) + items["cbw"] = p(6, float) + items["analyst"] = p(7, str) + + elif em_band in ("X-ray", "Gamma"): + items["fov_arcsec"] = p(1, float) + items["central_energy_kev"] = p(2, float) + items["energy_bandpass_kev"] = p(3, float) + items["analyst"] = p(4, str) else: raise ValueError(f"EM Band {em_band!r} not supported") @@ -122,28 +124,29 @@ def p(n, vtype): class PrestoInf(dict): - """ Parse PRESTO's .inf files that contain dedispersed time series metadata. """ + """Parse PRESTO's .inf files that contain dedispersed time series metadata.""" + def __init__(self, fname): self._fname = os.path.realpath(fname) - with open(fname, 'r') as fobj: + with open(fname, "r") as fobj: items = inf2dict(fobj.read()) super(PrestoInf, self).__init__(items) @property def fname(self): - """ Absolute path to original file. 
""" + """Absolute path to original file.""" return self._fname @property def data_fname(self): - """ Path to the associated .dat file """ - return self.fname.rsplit('.', maxsplit=1)[0] + '.dat' + """Path to the associated .dat file""" + return self.fname.rsplit(".", maxsplit=1)[0] + ".dat" @property def skycoord(self): - """ astropy.SkyCoord object with the coordinates of the source. """ - return SkyCoord(self['raj'], self['decj'], unit=(uu.hour, uu.degree)) + """astropy.SkyCoord object with the coordinates of the source.""" + return SkyCoord(self["raj"], self["decj"], unit=(uu.hour, uu.degree)) def load_data(self): - """ Returns the associated time series data as a numpy float32 array. """ + """Returns the associated time series data as a numpy float32 array.""" return numpy.fromfile(self.data_fname, dtype=numpy.float32) diff --git a/riptide/reading/sigproc.py b/src/riptide/reading/sigproc.py similarity index 71% rename from riptide/reading/sigproc.py rename to src/riptide/reading/sigproc.py index 3595d2e..410f393 100644 --- a/riptide/reading/sigproc.py +++ b/src/riptide/reading/sigproc.py @@ -1,4 +1,5 @@ -""" Read dedispersed time series from SIGPROC. """ +"""Read dedispersed time series from SIGPROC.""" + ##### Standard imports ##### import os import struct @@ -11,7 +12,7 @@ # SIGPROC keys and associated data types # Copied from Ewan Barr's sigpyproc -# NOTE: +# NOTE: # Values of type int are assumed to be stored as 32-bit in the header # Values of type float are assumed to be stored as 64-bit (C double) # Values of type bool are assumed to be stored as an unsigned char (8-bit) @@ -56,19 +57,18 @@ "obs_date": str, "obs_time": str, "accel": float, - "signed": bool - } + "signed": bool, +} # These flags mark the boundaries of the header in a SIGPROC data file -HEADER_START = 'HEADER_START' -HEADER_END = 'HEADER_END' - +HEADER_START = "HEADER_START" +HEADER_END = "HEADER_END" def read_str(fobj): - """ Read string from open binary file object. """ - size, = struct.unpack('i', fobj.read(4)) + """Read string from open binary file object.""" + (size,) = struct.unpack("i", fobj.read(4)) # NOTE: reading a string in a binary file via fobj.read() returns # a str in python2, but a bytes in python3. In python3 we need to called # bytes.decode() to get a str, but in python2 calling decode() gives us @@ -78,33 +78,35 @@ def read_str(fobj): def read_attribute(fobj, keydb): - """ Read SIGPROC {key, value} pair from open binary file object. 
""" + """Read SIGPROC {key, value} pair from open binary file object.""" key = read_str(fobj) if key == HEADER_END: return key, None atype = keydb.get(key, None) if atype is None: - errmsg = 'Type of SIGPROC header attribute \'{0:s}\' is unknown, please specify it'.format(key) + errmsg = "Type of SIGPROC header attribute '{0:s}' is unknown, please specify it".format( + key + ) raise KeyError(errmsg) if atype == str: val = read_str(fobj) elif atype == int: - val, = struct.unpack('i', fobj.read(4)) + (val,) = struct.unpack("i", fobj.read(4)) elif atype == float: - val, = struct.unpack('d', fobj.read(8)) + (val,) = struct.unpack("d", fobj.read(8)) elif atype == bool: - val, = struct.unpack('B', fobj.read(1)) # B = unsigned char + (val,) = struct.unpack("B", fobj.read(1)) # B = unsigned char val = bool(val) else: - errmsg = 'Key \'{0:s}\' has unsupported type \'{1:s}\''.format(key, atype) + errmsg = "Key '{0:s}' has unsupported type '{1:s}'".format(key, atype) raise ValueError(errmsg) return key, val def read_sigproc_header(fobj, extra_keys={}): - """ Read SIGPROC header from an open file object + """Read SIGPROC header from an open file object Parameters ---------- @@ -131,7 +133,9 @@ def read_sigproc_header(fobj, extra_keys={}): # Read HEADER_START flag fobj.seek(0) flag = read_str(fobj) - errmsg = 'File starts with \'{0:s}\' flag instead of the expected \'{1:s}\''.format(flag, HEADER_START) + errmsg = "File starts with '{0:s}' flag instead of the expected '{1:s}'".format( + flag, HEADER_START + ) assert flag == HEADER_START, errmsg # Read all header attributes @@ -146,52 +150,53 @@ def read_sigproc_header(fobj, extra_keys={}): def parse_float_coord(f): - """ Parse coordinate in SIGPROC's own decimal floating point, + """Parse coordinate in SIGPROC's own decimal floating point, to either hours (RA) or degrees (Dec). """ sign = np.sign(f) x = abs(f) - hh, x = divmod(x, 10000.) - mm, ss = divmod(x, 100.) + hh, x = divmod(x, 10000.0) + mm, ss = divmod(x, 100.0) return sign * (hh + mm / 60.0 + ss / 3600.0) class SigprocHeader(dict): - """ dict-like object wrapping the information carried by the header of a - SIGPROC file. """ + """dict-like object wrapping the information carried by the header of a + SIGPROC file.""" + def __init__(self, fname, extra_keys={}): self._fname = os.path.abspath(fname) - with open(self.fname, 'rb') as fobj: + with open(self.fname, "rb") as fobj: (attrs, self._bytesize) = read_sigproc_header(fobj, extra_keys) super(SigprocHeader, self).__init__(attrs) @property def fname(self): - """ Absolute path to original file. """ + """Absolute path to original file.""" return self._fname @property def bytesize(self): - """ Number of bytes occupied by the header in the original file. """ + """Number of bytes occupied by the header in the original file.""" return self._bytesize @property def bytes_per_sample(self): - return self['nchans'] * self['nbits'] // 8 + return self["nchans"] * self["nbits"] // 8 @property def nsamp(self): - """ Number of samples in the data """ + """Number of samples in the data""" return (os.path.getsize(self.fname) - self.bytesize) // self.bytes_per_sample @property def tobs(self): - """ Total length of the data in seconds """ - return self.nsamp * self['tsamp'] + """Total length of the data in seconds""" + return self.nsamp * self["tsamp"] @property def skycoord(self): - """ astropy.SkyCoord object with the coordinates of the source. 
""" - rajd = parse_float_coord(self['src_raj']) - dejd = parse_float_coord(self['src_dej']) - return SkyCoord(rajd, dejd, unit=(uu.hour, uu.degree), frame='icrs') + """astropy.SkyCoord object with the coordinates of the source.""" + rajd = parse_float_coord(self["src_raj"]) + dejd = parse_float_coord(self["src_dej"]) + return SkyCoord(rajd, dejd, unit=(uu.hour, uu.degree), frame="icrs") diff --git a/riptide/running_medians.py b/src/riptide/running_medians.py similarity index 97% rename from riptide/running_medians.py rename to src/riptide/running_medians.py index 97af9f6..8b53523 100644 --- a/riptide/running_medians.py +++ b/src/riptide/running_medians.py @@ -74,10 +74,12 @@ def fast_running_median(data, width_samples, min_points=101): raise ValueError("min_points must be an odd number") scrunch_factor = int(max(1, width_samples / float(min_points))) - if (scrunch_factor == 1): + if scrunch_factor == 1: return running_median(data, width_samples) scrunched_data = scrunch(data, scrunch_factor) rmed_lores = running_median(scrunched_data, min_points) - x_lores = np.arange(scrunched_data.size) * scrunch_factor + 0.5 * (scrunch_factor - 1) + x_lores = np.arange(scrunched_data.size) * scrunch_factor + 0.5 * ( + scrunch_factor - 1 + ) return np.interp(np.arange(data.size), x_lores, rmed_lores) diff --git a/riptide/search.py b/src/riptide/search.py similarity index 86% rename from riptide/search.py rename to src/riptide/search.py index c6d00ef..8f5bcb2 100644 --- a/riptide/search.py +++ b/src/riptide/search.py @@ -8,9 +8,21 @@ @timing -def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, bins_max=260, - ducy_max=0.20, wtsp=1.5, deredden=True, rmed_width=4.0, rmed_minpts=101, already_normalised=False): - """ +def ffa_search( + tseries, + period_min=1.0, + period_max=30.0, + fpmin=8, + bins_min=240, + bins_max=260, + ducy_max=0.20, + wtsp=1.5, + deredden=True, + rmed_width=4.0, + rmed_minpts=101, + already_normalised=False, +): + """ Run a FFA search of a single TimeSeries object, producing its periodogram. Parameters @@ -27,7 +39,7 @@ def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, bins_min : int Minimum number of phase bins in the folded data. Higher values provide better duty cycle resolution. As the code searches longer trial - periods, the data are iteratively downsampled so that the number of + periods, the data are iteratively downsampled so that the number of phase bins remains between bins_min and bins_max bins_max : int Maximum number of phase bins in the folded data. Must be strictly @@ -41,14 +53,14 @@ def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, phase bins): 1, 2, 3, 4, 6, 9, 13, 19 ... ducy_max : float Maximum duty cycle to optimally search. Limits the maximum width of the - boxcar matched filters applied to any given profile. - Example: on a 300 phase bin profile, ducy_max = 0.2 means that no + boxcar matched filters applied to any given profile. 
+ Example: on a 300 phase bin profile, ducy_max = 0.2 means that no boxcar filter of width > 60 bins will be applied deredden : bool Subtract red noise from the time series before searching rmed_width : float The width of the running median filter to subtract from the input data - before processing, in seconds + before processing, in seconds rmed_minpts : int The running median is calculated of a time scrunched version of the input data to save time: rmed_minpts is the minimum number of @@ -56,7 +68,7 @@ def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, Lower values make the running median calculation less accurate but faster, due to allowing a higher scrunching factor already_normalised : bool - Assume that the data are already normalised to zero mean and unit + Assume that the data are already normalised to zero mean and unit standard deviation Returns @@ -64,7 +76,7 @@ def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, ts : TimeSeries The de-reddened and normalised time series that was actually searched pgram : Periodogram - The output of the search, which contains among other things a 2D array + The output of the search, which contains among other things a 2D array representing S/N as a function of trial period and trial width. """ ### Prepare data: deredden then normalise IN THAT ORDER @@ -75,8 +87,7 @@ def ffa_search(tseries, period_min=1.0, period_max=30.0, fpmin=8, bins_min=240, widths = generate_width_trials(bins_min, ducy_max=ducy_max, wtsp=wtsp) periods, foldbins, snrs = libcpp.periodogram( - tseries.data, tseries.tsamp, widths, - period_min, period_max, bins_min, bins_max - ) + tseries.data, tseries.tsamp, widths, period_min, period_max, bins_min, bins_max + ) pgram = Periodogram(widths, periods, foldbins, snrs, metadata=tseries.metadata) return tseries, pgram diff --git a/riptide/serialization.py b/src/riptide/serialization.py similarity index 66% rename from riptide/serialization.py rename to src/riptide/serialization.py index 122d778..17a1f5d 100644 --- a/riptide/serialization.py +++ b/src/riptide/serialization.py @@ -11,19 +11,20 @@ # NOTE: Using importlib avoids placing "import riptide" at the top of the file, # which solves circular import issues def get_riptide_version(): - riptide = importlib.import_module('riptide') - return getattr(riptide, '__version__') + riptide = importlib.import_module("riptide") + return getattr(riptide, "__version__") # NOTE: it is implicitly assumed that any JSON-serializable riptide class is in # the *base* riptide module def get_class(clsname): - riptide = importlib.import_module('riptide') + riptide = importlib.import_module("riptide") return getattr(riptide, clsname) class JSONEncoder(json.JSONEncoder): """ """ + def default(self, obj): # NOTE: this method is called only for types not supported by default # Since Metadata is a dict (supported by default), it *never* gets called @@ -35,10 +36,10 @@ def default(self, obj): b64_bytes = base64.b64encode(np.ascontiguousarray(obj).data) b64_str = b64_bytes.decode() return { - '__type__': 'numpy.ndarray', - 'data': b64_str, - 'dtype': str(obj.dtype), - 'shape': obj.shape + "__type__": "numpy.ndarray", + "data": b64_str, + "dtype": str(obj.dtype), + "shape": obj.shape, } # Handle numpy numeric types @@ -50,55 +51,55 @@ def default(self, obj): if isinstance(obj, pandas.DataFrame): return { - '__type__': 'pandas.DataFrame', - 'values': self.default(obj.values), - 'columns': list(obj.columns) + "__type__": "pandas.DataFrame", + "values": 
self.default(obj.values),
+                "columns": list(obj.columns),
             }
 
         if isinstance(obj, SkyCoord):
             return {
-                '__type__': 'astropy.SkyCoord',
-                'rajd': obj.icrs.ra.deg,
-                'decjd': obj.icrs.dec.deg,
-                'frame': 'icrs'
+                "__type__": "astropy.SkyCoord",
+                "rajd": obj.icrs.ra.deg,
+                "decjd": obj.icrs.dec.deg,
+                "frame": "icrs",
             }
 
-        # If we reach that point, we assume that anything with a to_dict() method 
+        # If we reach this point, we assume that anything with a to_dict() method
         # is a riptide serializable object
-        if hasattr(obj, 'to_dict'):
+        if hasattr(obj, "to_dict"):
             items = obj.to_dict()
-            items['__type__'] = type(obj).__name__
+            items["__type__"] = type(obj).__name__
 
             # Version the object properly
-            if hasattr(obj, 'version') and obj.version:
-                items['__version__'] = obj.version
+            if hasattr(obj, "version") and obj.version:
+                items["__version__"] = obj.version
             else:
-                items['__version__'] = get_riptide_version()
+                items["__version__"] = get_riptide_version()
             return items
-        
+
         return super(JSONEncoder, self).default(obj)
 
 
 def object_hook(items):
-    if not '__type__' in items:
+    if "__type__" not in items:
         return items
 
-    typename = items['__type__']
+    typename = items["__type__"]
 
-    if typename == 'numpy.ndarray':
-        b64_bytes = items['data'].encode()
+    if typename == "numpy.ndarray":
+        b64_bytes = items["data"].encode()
         data = base64.b64decode(b64_bytes)
-        return np.frombuffer(data, items['dtype']).reshape(items['shape'])
+        return np.frombuffer(data, items["dtype"]).reshape(items["shape"])
 
     # NOTE: decoding is done from the deepest nodes of the tree first
     # Which means items['values'] is already a numpy array here
-    if typename == 'pandas.DataFrame':
-        return pandas.DataFrame(items['values'], columns=items['columns'])
+    if typename == "pandas.DataFrame":
+        return pandas.DataFrame(items["values"], columns=items["columns"])
 
-    if typename == 'astropy.SkyCoord':
-        ra = items['rajd']
-        dec = items['decjd']
-        frame = items['frame']
+    if typename == "astropy.SkyCoord":
+        ra = items["rajd"]
+        dec = items["decjd"]
+        frame = items["frame"]
         return SkyCoord(ra * uu.deg, dec * uu.deg, frame=frame)
 
     # If typename is not one of the above, then we assume it is
@@ -108,8 +109,8 @@ def object_hook(items):
     obj = cls.from_dict(items)
 
     # Handle version
-    if '__version__' in items:
-        version = items['__version__']
+    if "__version__" in items:
+        version = items["__version__"]
     else:
         version = get_riptide_version()
     obj.version = version
@@ -130,8 +131,8 @@ def to_json(obj, **kwargs):
     are passed to json.dumps().
     """
     kwargs = dict(kwargs)
-    kwargs['cls'] = JSONEncoder
-    kwargs.setdefault('indent', 4)
+    kwargs["cls"] = JSONEncoder
+    kwargs.setdefault("indent", 4)
     return json.dumps(obj, **kwargs)
 
 
@@ -139,14 +140,14 @@ def load_json(fname):
     """
     Load a JSON file containing a riptide object (or list/dict/composition thereof)
     """
-    with open(fname, 'r') as f:
+    with open(fname, "r") as f:
         return from_json(f.read())
 
 
 def save_json(fname, obj, **kwargs):
     """
-    Save riptide object (or list/dict/composition thereof) to a JSON file. Any keyword arguments are 
+    Save riptide object (or list/dict/composition thereof) to a JSON file. Any keyword arguments are
     passed to json.dumps().
     """
-    with open(fname, 'w') as f:
+    with open(fname, "w") as f:
         f.write(to_json(obj, **kwargs))
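
Side note — a minimal round-trip sketch of the serialization helpers above (the
file name and the TimeSeries.generate() arguments here are illustrative only,
not taken from the patch):

    import numpy as np
    from riptide import TimeSeries, save_json, load_json

    # 10 s of fake data: length (s), sampling time (s), signal period (s)
    ts = TimeSeries.generate(10.0, 0.001, 1.0, amplitude=20.0)

    # JSONEncoder turns the underlying float32 array into a base64 string;
    # 'indent' defaults to 4 unless overridden via **kwargs
    save_json("ts.json", ts)

    # object_hook decodes the deepest nodes first, then rebuilds the
    # TimeSeries via its from_dict() classmethod
    ts2 = load_json("ts.json")
    assert np.allclose(ts.data, ts2.data)
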
""" - with open(fname, 'w') as f: + with open(fname, "w") as f: f.write(to_json(obj, **kwargs)) diff --git a/riptide/time_series.py b/src/riptide/time_series.py similarity index 80% rename from riptide/time_series.py rename to src/riptide/time_series.py index ebeb294..6ee1911 100644 --- a/riptide/time_series.py +++ b/src/riptide/time_series.py @@ -14,7 +14,7 @@ class TimeSeries(object): - """ + """ Container for time series data to be searched with the FFA. **Use classmethods to create a new TimeSeries object.** @@ -37,6 +37,7 @@ class TimeSeries(object): TimeSeries.from_sigproc : Load dedispersed data produced by SIGPROC TimeSeries.generate : Generate a noisy time series containing a fake pulsar signal """ + def __init__(self, data, tsamp, metadata=None, copy=False): if copy: self._data = np.asarray(data, dtype=np.float32).copy() @@ -47,24 +48,24 @@ def __init__(self, data, tsamp, metadata=None, copy=False): # Carrying a tobs attribute is quite practical in later stages of # the pipeline (peak detection in periodograms) - self.metadata['tobs'] = self.length + self.metadata["tobs"] = self.length @property def data(self): - """ numpy array holding the time series data, in float32 format. """ + """numpy array holding the time series data, in float32 format.""" return self._data @property def tsamp(self): - """ Sampling time in seconds. """ + """Sampling time in seconds.""" return self._tsamp def copy(self): - """ Returns a new copy of the TimeSeries """ + """Returns a new copy of the TimeSeries""" return copy.deepcopy(self) def normalise(self, inplace=False): - """ Normalise to zero mean and unit variance. if 'inplace' is False, + """Normalise to zero mean and unit variance. if 'inplace' is False, a new TimeSeries object with the normalized data is returned. Parameters @@ -82,17 +83,18 @@ def normalise(self, inplace=False): # data have large values m = self.data.mean(dtype=np.float64) v = self.data.var(dtype=np.float64) - norm = v ** 0.5 - + norm = v**0.5 + if inplace: self._data = (self.data - m) / norm else: - return TimeSeries((self.data - m) / norm, self.tsamp, metadata=self.metadata) - + return TimeSeries( + (self.data - m) / norm, self.tsamp, metadata=self.metadata + ) @timing def deredden(self, width, minpts=101, inplace=False): - """ Subtract from the data an aproximate running median. To save time, + """Subtract from the data an aproximate running median. To save time, this running median is computed on a downsampled copy of the data, then upsampled back to the original resolution and finally subtracted from the original data. @@ -122,7 +124,7 @@ def deredden(self, width, minpts=101, inplace=False): return TimeSeries(self.data - rmed, self.tsamp, metadata=self.metadata) def downsample(self, factor, inplace=False): - """ Downsample data by a real-valued factor, by grouping and adding + """Downsample data by a real-valued factor, by grouping and adding together consecutive samples (or fractions of samples). Parameters @@ -142,7 +144,11 @@ def downsample(self, factor, inplace=False): self._data = downsample(self.data, factor) self._tsamp *= factor else: - return TimeSeries(downsample(self.data, factor), factor * self.tsamp, metadata=self.metadata) + return TimeSeries( + downsample(self.data, factor), + factor * self.tsamp, + metadata=self.metadata, + ) def fold(self, period, bins, subints=None): """ @@ -155,8 +161,8 @@ def fold(self, period, bins, subints=None): bins : int Number of phase bins subints: int or None, optional - Number of desired sub-integrations. 
If None, the number of - sub-integrations will be the number of full periods that fit in + Number of desired sub-integrations. If None, the number of + sub-integrations will be the number of full periods that fit in the data Returns @@ -168,8 +174,10 @@ def fold(self, period, bins, subints=None): return fold(self, period, bins, subints=subints) @classmethod - def generate(cls, length, tsamp, period, phi0=0.5, ducy=0.02, amplitude=10.0, stdnoise=1.0): - """ + def generate( + cls, length, tsamp, period, phi0=0.5, ducy=0.02, amplitude=10.0, stdnoise=1.0 + ): + """ Generate a time series containing a periodic signal with a von Mises pulse profile, and some background white noise (optional). @@ -189,12 +197,12 @@ def generate(cls, length, tsamp, period, phi0=0.5, ducy=0.02, amplitude=10.0, st True amplitude of the signal as defined in the reference paper. The *expectation* of the S/N of the generated signal is S/N_true = amplitude / stdnoise, - assuming that a matched filter with the exact shape of the pulse is - employed to measure S/N (here: von Mises with given duty cycle). + assuming that a matched filter with the exact shape of the pulse is + employed to measure S/N (here: von Mises with given duty cycle). riptide employs boxcar filters in the search, which results in a slight - S/N loss. See the reference paper for details. + S/N loss. See the reference paper for details. A further degradation will be observed on bright signals, because - they bias the estimation of the mean and standard deviation of the + they bias the estimation of the mean and standard deviation of the noise in a blind search. stdnoise : float, optional Standard deviation of the background noise (default: 1.0). @@ -207,19 +215,28 @@ def generate(cls, length, tsamp, period, phi0=0.5, ducy=0.02, amplitude=10.0, st """ nsamp = int(round(length / tsamp)) period_samples = period / tsamp - data = generate_signal(nsamp, period_samples, phi0=phi0, ducy=ducy, amplitude=amplitude, stdnoise=stdnoise) - metadata = Metadata({ - 'source_name': 'fake', - 'signal_shape': 'Von Mises', - 'signal_period': period, - 'signal_initial_phase': phi0, - 'signal_duty_cycle': ducy, - }) + data = generate_signal( + nsamp, + period_samples, + phi0=phi0, + ducy=ducy, + amplitude=amplitude, + stdnoise=stdnoise, + ) + metadata = Metadata( + { + "source_name": "fake", + "signal_shape": "Von Mises", + "signal_period": period, + "signal_initial_phase": phi0, + "signal_duty_cycle": ducy, + } + ) return cls(data, tsamp, copy=False, metadata=metadata) @classmethod def from_numpy_array(cls, array, tsamp, copy=False): - """ Create a new TimeSeries from a numpy array (or array-like). + """Create a new TimeSeries from a numpy array (or array-like). Parameters ---------- @@ -230,7 +247,7 @@ def from_numpy_array(cls, array, tsamp, copy=False): copy : bool, optional If set to True, the resulting time series will hold a new copy of 'array', otherwise it only holds a reference to it - + Returns ------- out: TimeSeries @@ -240,7 +257,7 @@ def from_numpy_array(cls, array, tsamp, copy=False): @classmethod def from_binary(cls, fname, tsamp, dtype=np.float32): - """ Create a new TimeSeries from a raw binary file, containing the + """Create a new TimeSeries from a raw binary file, containing the time series data without any header or footer. This will work as long as the data can be loaded with numpy.fromfile(). 
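
Side note — a short sketch of how these constructors are meant to be called
(the array contents, file name and tsamp values are illustrative only):

    import numpy as np
    from riptide import TimeSeries

    data = np.random.normal(size=10000).astype(np.float32)

    # Wrap an in-memory array; with copy=False (the default) only a
    # reference to 'data' is held
    ts = TimeSeries.from_numpy_array(data, tsamp=256e-6)

    # The same samples read back from a headerless binary file of float32
    data.tofile("example.dat")
    ts2 = TimeSeries.from_binary("example.dat", tsamp=256e-6)
    assert ts.nsamp == ts2.nsamp
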
@@ -263,7 +280,7 @@ def from_binary(cls, fname, tsamp, dtype=np.float32): @classmethod def from_npy_file(cls, fname, tsamp): - """ Create a new TimeSeries from a .npy file, written with numpy.save(). + """Create a new TimeSeries from a .npy file, written with numpy.save(). Parameters ---------- @@ -283,7 +300,7 @@ def from_npy_file(cls, fname, tsamp): @classmethod @timing def from_presto_inf(cls, fname): - """ Create a new TimeSeries from a .inf file written by PRESTO. The + """Create a new TimeSeries from a .inf file written by PRESTO. The associated .dat file must be in the same directory. Parameters @@ -301,10 +318,10 @@ def from_presto_inf(cls, fname): # TODO: check that the number of samples read from the .inf file # matches what is actually in the .dat file, although the possibility of # 'data breaks' could make this difficult - ts = cls(inf.load_data(), tsamp=inf['tsamp'], metadata=metadata) + ts = cls(inf.load_data(), tsamp=inf["tsamp"], metadata=metadata) - em_band = ts.metadata['em_band'] - if em_band in ('X-ray', 'Gamma'): + em_band = ts.metadata["em_band"] + if em_band in ("X-ray", "Gamma"): msg = ( f" You have loaded file {fname!r}, which contains data observed at" f" a high-energy band {em_band!r}." @@ -318,7 +335,7 @@ def from_presto_inf(cls, fname): @classmethod @timing def from_sigproc(cls, fname, extra_keys={}): - """ Create a new TimeSeries from a file written by SIGPROC's dedisperse + """Create a new TimeSeries from a file written by SIGPROC's dedisperse routine. Parameters @@ -350,35 +367,37 @@ def from_sigproc(cls, fname, extra_keys={}): metadata = Metadata.from_sigproc(sig, extra_keys=extra_keys) # Load time series data - with open(fname, 'rb') as fobj: + with open(fname, "rb") as fobj: fobj.seek(sig.bytesize) - if metadata['nbits'] == 8: - dtype = np.int8 if metadata['signed'] else np.uint8 + if metadata["nbits"] == 8: + dtype = np.int8 if metadata["signed"] else np.uint8 # Don't forget to cast to float32 after reading ! data = np.fromfile(fobj, dtype=dtype).astype(np.float32) - else: # assume float32 + else: # assume float32 data = np.fromfile(fobj, dtype=np.float32) - return cls(data, tsamp=sig['tsamp'], metadata=metadata) + return cls(data, tsamp=sig["tsamp"], metadata=metadata) @property def nsamp(self): - """ Number of samples in the data. """ + """Number of samples in the data.""" return self.data.size @property def length(self): - """ Length of the data in seconds """ + """Length of the data in seconds""" return self.nsamp * self.tsamp @property def tobs(self): - """ Length of the data in seconds. Alias of property 'length'. """ + """Length of the data in seconds. 
Alias of property 'length'.""" return self.length def __str__(self): name = type(self).__name__ - out = '{name} {{nsamp = {x.nsamp:d}, tsamp = {x.tsamp:.4e}, tobs = {x.length:.3f}}}'.format(name=name, x=self) + out = "{name} {{nsamp = {x.nsamp:d}, tsamp = {x.tsamp:.4e}, tobs = {x.length:.3f}}}".format( + name=name, x=self + ) return out def __repr__(self): @@ -386,12 +405,9 @@ def __repr__(self): @classmethod def from_dict(cls, items): - return cls(items['data'], items['tsamp'], metadata=items['metadata'], copy=False) + return cls( + items["data"], items["tsamp"], metadata=items["metadata"], copy=False + ) def to_dict(self): - return { - 'data': self.data, - 'tsamp': self.tsamp, - 'metadata': self.metadata - } - + return {"data": self.data, "tsamp": self.tsamp, "metadata": self.metadata} diff --git a/riptide/timing.py b/src/riptide/timing.py similarity index 83% rename from riptide/timing.py rename to src/riptide/timing.py index 3731ca2..af131a6 100644 --- a/riptide/timing.py +++ b/src/riptide/timing.py @@ -6,11 +6,12 @@ def timing(func, *args, **kwargs): @wraps(func) def wrapped(*args, **kwargs): - log = logging.getLogger('riptide.timing') + log = logging.getLogger("riptide.timing") t0 = time.time() output = func(*args, **kwargs) t1 = time.time() dt = t1 - t0 log.debug("{!r} runtime: {:.2f} ms".format(func.__name__, dt * 1000.0)) return output - return wrapped \ No newline at end of file + + return wrapped diff --git a/riptide/tests/data/README.md b/tests/data/README.md similarity index 100% rename from riptide/tests/data/README.md rename to tests/data/README.md diff --git a/riptide/tests/data/fake_presto_radio.dat b/tests/data/fake_presto_radio.dat similarity index 100% rename from riptide/tests/data/fake_presto_radio.dat rename to tests/data/fake_presto_radio.dat diff --git a/riptide/tests/data/fake_presto_radio.inf b/tests/data/fake_presto_radio.inf similarity index 100% rename from riptide/tests/data/fake_presto_radio.inf rename to tests/data/fake_presto_radio.inf diff --git a/riptide/tests/data/fake_presto_radio_breaks.dat b/tests/data/fake_presto_radio_breaks.dat similarity index 100% rename from riptide/tests/data/fake_presto_radio_breaks.dat rename to tests/data/fake_presto_radio_breaks.dat diff --git a/riptide/tests/data/fake_presto_radio_breaks.inf b/tests/data/fake_presto_radio_breaks.inf similarity index 100% rename from riptide/tests/data/fake_presto_radio_breaks.inf rename to tests/data/fake_presto_radio_breaks.inf diff --git a/riptide/tests/data/fake_presto_xray.dat b/tests/data/fake_presto_xray.dat similarity index 100% rename from riptide/tests/data/fake_presto_xray.dat rename to tests/data/fake_presto_xray.dat diff --git a/riptide/tests/data/fake_presto_xray.inf b/tests/data/fake_presto_xray.inf similarity index 100% rename from riptide/tests/data/fake_presto_xray.inf rename to tests/data/fake_presto_xray.inf diff --git a/riptide/tests/data/fake_sigproc_float32.tim b/tests/data/fake_sigproc_float32.tim similarity index 100% rename from riptide/tests/data/fake_sigproc_float32.tim rename to tests/data/fake_sigproc_float32.tim diff --git a/riptide/tests/data/fake_sigproc_int8.tim b/tests/data/fake_sigproc_int8.tim similarity index 100% rename from riptide/tests/data/fake_sigproc_int8.tim rename to tests/data/fake_sigproc_int8.tim diff --git a/riptide/tests/data/fake_sigproc_uint8.tim b/tests/data/fake_sigproc_uint8.tim similarity index 100% rename from riptide/tests/data/fake_sigproc_uint8.tim rename to tests/data/fake_sigproc_uint8.tim diff --git 
a/riptide/tests/data/fake_sigproc_uint8_nosignedkey.tim b/tests/data/fake_sigproc_uint8_nosignedkey.tim similarity index 100% rename from riptide/tests/data/fake_sigproc_uint8_nosignedkey.tim rename to tests/data/fake_sigproc_uint8_nosignedkey.tim diff --git a/riptide/tests/pipeline_config_A.yml b/tests/pipeline_config_A.yml similarity index 100% rename from riptide/tests/pipeline_config_A.yml rename to tests/pipeline_config_A.yml diff --git a/riptide/tests/pipeline_config_B.yml b/tests/pipeline_config_B.yml similarity index 100% rename from riptide/tests/pipeline_config_B.yml rename to tests/pipeline_config_B.yml diff --git a/tests/presto_generation.py b/tests/presto_generation.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/presto_generation.py @@ -0,0 +1 @@ + diff --git a/riptide/tests/test_ffa_base_functions.py b/tests/test_ffa_base_functions.py similarity index 58% rename from riptide/tests/test_ffa_base_functions.py rename to tests/test_ffa_base_functions.py index 753b606..b7e6a1a 100644 --- a/riptide/tests/test_ffa_base_functions.py +++ b/tests/test_ffa_base_functions.py @@ -9,27 +9,31 @@ # Manually calculated 8x8 test case # This is expected to be invariant under phase rotation # and invariant when appending columns made of zeros -FFA_IN_88 = np.array([ - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 1] -]).astype(np.float32) - -FFA_OUT_88 = np.array([ - [0, 0, 0, 0, 0, 0, 0, 8], - [0, 0, 0, 0, 0, 0, 4, 4], - [0, 0, 0, 0, 0, 2, 4, 2], - [0, 0, 0, 0, 2, 2, 2, 2], - [0, 0, 0, 1, 2, 2, 2, 1], - [0, 0, 1, 2, 1, 1, 2, 1], - [0, 1, 1, 1, 2, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1] -]).astype(np.float32) +FFA_IN_88 = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + ] +).astype(np.float32) + +FFA_OUT_88 = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 8], + [0, 0, 0, 0, 0, 0, 4, 4], + [0, 0, 0, 0, 0, 2, 4, 2], + [0, 0, 0, 0, 2, 2, 2, 2], + [0, 0, 0, 1, 2, 2, 2, 1], + [0, 0, 1, 2, 1, 1, 2, 1], + [0, 1, 1, 1, 2, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + ] +).astype(np.float32) def test_transforms(): @@ -43,13 +47,9 @@ def test_transforms(): assert np.allclose(Z, truth) for extra_cols in range(8): - X = np.hstack([ - FFA_IN_88, np.zeros(shape=(8, extra_cols)) - ]) + X = np.hstack([FFA_IN_88, np.zeros(shape=(8, extra_cols))]) m, p = X.shape - truth = np.hstack([ - FFA_OUT_88, np.zeros(shape=(8, extra_cols)) - ]) + truth = np.hstack([FFA_OUT_88, np.zeros(shape=(8, extra_cols))]) Y = ffa2(X) Z = ffa1(X.ravel(), p) assert np.allclose(Y, truth) @@ -57,23 +57,23 @@ def test_transforms(): # Test ffa2() errors with raises(ValueError): - X = np.zeros(4) # 1D array + X = np.zeros(4) # 1D array ffa2(X) # Test ffa1() errors # Input not 1D with raises(ValueError): - X = np.zeros((4, 4)) # 2D array + X = np.zeros((4, 4)) # 2D array ffa1(X, 4) - + X = np.zeros(10) with raises(ValueError): - ffa1(X, X.size + 1) # p too large + ffa1(X, X.size + 1) # p too large with raises(ValueError): - ffa1(X, 4.0) # p not int - + ffa1(X, 4.0) # p not int + def test_ffafreq(): # NOTE: ffaprd() simply does 1.0 / ffafreq(), so we only need to properly @@ -84,10 +84,10 @@ def test_ffafreq(): m = 42 p = 127 N = m * p - dt = np.pi / 1000.0 # pi milliseconds 
sampling time, why not - + dt = np.pi / 1000.0 # pi milliseconds sampling time, why not + s = np.arange(m, dtype=float) - true_periods = p**2 / (p - s/(m-1.0)) * dt + true_periods = p**2 / (p - s / (m - 1.0)) * dt true_freqs = 1.0 / true_periods freqs = ffafreq(N, p, dt=dt) @@ -100,19 +100,19 @@ def test_ffafreq(): ### Errors with raises(ValueError): - ffafreq(0, p, dt=dt) # N <= 0 + ffafreq(0, p, dt=dt) # N <= 0 with raises(ValueError): - ffafreq(np.pi, p, dt=dt) # N != int + ffafreq(np.pi, p, dt=dt) # N != int with raises(ValueError): - ffafreq(N, 1, dt=dt) # p <= 1 + ffafreq(N, 1, dt=dt) # p <= 1 with raises(ValueError): - ffafreq(N, np.pi, dt=dt) # p != int + ffafreq(N, np.pi, dt=dt) # p != int with raises(ValueError): - ffafreq(N, N + 1, dt=dt) # p > n - + ffafreq(N, N + 1, dt=dt) # p > n + with raises(ValueError): - ffafreq(N, p, dt=0) # dt <= 0 + ffafreq(N, p, dt=0) # dt <= 0 diff --git a/riptide/tests/test_ffa_search_pgram.py b/tests/test_ffa_search_pgram.py similarity index 70% rename from riptide/tests/test_ffa_search_pgram.py rename to tests/test_ffa_search_pgram.py index 1286834..5768ae0 100644 --- a/riptide/tests/test_ffa_search_pgram.py +++ b/tests/test_ffa_search_pgram.py @@ -3,14 +3,12 @@ import numpy as np import matplotlib.pyplot as plt -from pytest import raises - -from riptide import TimeSeries, ffa_search, Periodogram, save_json, load_json +from riptide import TimeSeries, ffa_search, save_json, load_json def test_ffa_search(): - # NOTE: we chose a length long enough so that running the - # 'periodogram pruning' function was actually necessary + # NOTE: we chose a length long enough so that running the + # 'periodogram pruning' function was actually necessary # (and thus the function gets properly covered by the tests) length = 200.0 tsamp = 0.001 @@ -23,13 +21,15 @@ def test_ffa_search(): period_min = 0.8 * period period_max = 1.2 * period tsdr, pgram = ffa_search( - ts, - period_min=period_min, period_max=period_max, - bins_min=bins_min, bins_max=bins_max + ts, + period_min=period_min, + period_max=period_max, + bins_min=bins_min, + bins_max=bins_max, ) # check trial periods are increasing - assert all(np.maximum.accumulate(pgram.periods) == pgram.periods) + assert all(np.maximum.accumulate(pgram.periods) == pgram.periods) assert pgram.snrs.shape == (len(pgram.periods), len(pgram.widths)) assert pgram.metadata == ts.metadata == tsdr.metadata assert pgram.tobs == length @@ -39,16 +39,18 @@ def test_ffa_search(): # returns a reference to the input TimeSeries (data left untouched) # This is how ffa_search() is called by the pipeline tsdr, pgram = ffa_search( - ts, - period_min=period_min, period_max=period_max, - bins_min=bins_min, bins_max=bins_max, - already_normalised=True, deredden=False + ts, + period_min=period_min, + period_max=period_max, + bins_min=bins_min, + bins_max=bins_max, + already_normalised=True, + deredden=False, ) assert id(tsdr) == id(ts) - ### Periodogram serialization ### - with tempfile.NamedTemporaryFile(suffix='.json') as f: + with tempfile.NamedTemporaryFile(suffix=".json") as f: save_json(f.name, pgram) f.flush() pgram_copy = load_json(f.name) @@ -57,19 +59,18 @@ def test_ffa_search(): assert np.allclose(pgram.widths, pgram_copy.widths) assert pgram.metadata == pgram_copy.metadata - ### Periodogram plotting ### - plt.switch_backend('Agg') + plt.switch_backend("Agg") fig = plt.figure(figsize=(20, 5), dpi=100) pgram.plot() - with tempfile.NamedTemporaryFile(suffix='.png') as fobj: + with tempfile.NamedTemporaryFile(suffix=".png") as fobj: 
plt.savefig(fobj.name) plt.close(fig) # Same with iwidth = 0 fig = plt.figure(figsize=(20, 5), dpi=100) pgram.plot(iwidth=0) - with tempfile.NamedTemporaryFile(suffix='.png') as fobj: + with tempfile.NamedTemporaryFile(suffix=".png") as fobj: plt.savefig(fobj.name) plt.close(fig) @@ -90,7 +91,9 @@ def test_ffa_search_no_downsampling(): period_min = bins_min * tsamp period_max = bins_max * tsamp ffa_search( - ts, - period_min=period_min, period_max=period_max, - bins_min=bins_min, bins_max=bins_max + ts, + period_min=period_min, + period_max=period_max, + bins_min=bins_min, + bins_max=bins_max, ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..9b5e6d5 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,246 @@ +import os +import glob +import tempfile +from copy import deepcopy + +import yaml +import numpy as np +from pytest import raises +from riptide import load_json +from riptide import TimeSeries +from riptide.pipeline.pipeline import get_parser, run_program +from riptide.pipeline.config_validation import InvalidPipelineConfig, InvalidSearchRange + +# NOTE 1: +# pipeline uses multiprocessing, to get proper coverage stats we need: +# * A .coveragerc file with the following options: +# [run] +# branch = True +# parallel = True +# concurrency = multiprocessing +# * Ensure that all instances of multiprocessing.Pool() have been closed and joined, as follows: +# >> pool.close() +# >> pool.join() + +# NOTE 2: +# To print logging output in full, call pytest like this: +# pytest --capture=no -o log_cli=True + +# NOTE 3: +# To get coverage stats, run this in the base riptide directory: +# coverage run -m pytest && coverage combine && coverage report -m --omit src/riptide/_version.py + + +SIGNAL_PERIOD = 1.0 +DATA_TOBS = 128.0 +DATA_TSAMP = 256e-6 + + +INF_TEMPLATE = """ + Data file name without suffix = {basename:s} + Telescope used = Parkes + Instrument used = Multibeam + Object being observed = Pulsar + J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000 + J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000 + Data observed by = Kenji Oba + Epoch of observation (MJD) = 59000.000000 + Barycentered? (1=yes, 0=no) = 1 + Number of bins in the time series = {nsamp:d} + Width of each time series bin (sec) = {tsamp:.12e} + Any breaks in the data? (1=yes, 0=no) = 0 + Type of observation (EM band) = Radio + Beam diameter (arcsec) = 981 + Dispersion measure (cm-3 pc) = {dm:.12f} + Central freq of low channel (Mhz) = 1182.1953125 + Total bandwidth (Mhz) = 400 + Number of channels = 1024 + Channel bandwidth (Mhz) = 0.390625 + Data analyzed by = Space Sheriff Gavan + Any additional notes: + Input filterbank samples have 2 bits. +""" + + +def generate_data_presto( + outdir, + basename, + tobs=128.0, + tsamp=256e-6, + period=1.0, + dm=0.0, + amplitude=20.0, + ducy=0.05, +): + """ + Generate some time series data with a fake signal, and save it in PRESTO + inf/dat format in the specified output directory. + + Parameters + ---------- + outdir : str + Path to the output directory + basename : str + Base file name (not path) under which the .inf and .dat files + will be saved. 
+ **kwargs: self-explanatory + """ + ### IMPORTANT: seed the RNG to get reproducible results ### + np.random.seed(0) + + ts = TimeSeries.generate( + tobs, tsamp, period, amplitude=amplitude, ducy=ducy, stdnoise=1.0 + ) + inf_text = INF_TEMPLATE.format( + basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm + ) + + inf_path = os.path.join(outdir, f"{basename}.inf") + dat_path = os.path.join(outdir, f"{basename}.dat") + with open(inf_path, "w") as fobj: + fobj.write(inf_text) + ts.data.tofile(dat_path) + + +def runner_presto_fakepsr(fname_conf, outdir): + # Write test data + # NOTE: generate a signal bright enough to get harmonics and thus make sure + # that the harmonic filter gets to run + params = [ + # (dm, amplitude, ducy) + (0.0, 10.0, 0.05), + (10.0, 20.0, 0.02), + (20.0, 10.0, 0.05), + ] + + for dm, amplitude, ducy in params: + basename = f"fake_DM{dm:.3f}" + generate_data_presto( + outdir, + basename, + tobs=DATA_TOBS, + tsamp=DATA_TSAMP, + period=SIGNAL_PERIOD, + dm=dm, + amplitude=amplitude, + ducy=ducy, + ) + + ### Run pipeline ### + files = glob.glob(f"{outdir}/*.inf") + cmdline_args = ["--config", fname_conf, "--outdir", outdir] + files + parser = get_parser() + args = parser.parse_args(cmdline_args) + run_program(args) + + ### Check output sanity ### + topcand_fname = f"{outdir}/candidate_0000.json" + assert os.path.isfile(topcand_fname) + + topcand = load_json(topcand_fname) + + # NOTE: these checks depend on the RNG seed and the pipeline config + assert abs(topcand.params["period"] - SIGNAL_PERIOD) < 1.00e-4 + assert topcand.params["dm"] == 10.0 + assert topcand.params["width"] == 13 + assert abs(topcand.params["snr"] - 18.5) < 0.15 + + +def runner_presto_purenoise(fname_conf, outdir): + """ + Check that pipeline runs well even if no candidates are found + """ + dm = 0.0 + basename = f"purenoise_DM{dm:.3f}" + generate_data_presto( + outdir, + basename, + tobs=DATA_TOBS, + tsamp=DATA_TSAMP, + period=SIGNAL_PERIOD, + dm=dm, + amplitude=0.0, + ) + + ### Run pipeline ### + files = glob.glob(f"{outdir}/*.inf") + cmdline_args = ["--config", fname_conf, "--outdir", outdir] + files + parser = get_parser() + args = parser.parse_args(cmdline_args) + run_program(args) + + ### Check output sanity ### + assert not glob.glob(f"{outdir}/*.json") + assert not glob.glob(f"{outdir}/*.png") + + +def load_yaml(fname): + with open(fname, "r") as fobj: + return yaml.safe_load(fobj) + + +def save_yaml(items, fname): + with open(fname, "w") as fobj: + return yaml.safe_dump(items, fobj) + + +def test_pipeline_presto_fakepsr(): + # NOTE: outdir is a full path (str) + with tempfile.TemporaryDirectory() as outdir: + fname_conf = os.path.join(os.path.dirname(__file__), "pipeline_config_A.yml") + runner_presto_fakepsr(fname_conf, outdir) + + with tempfile.TemporaryDirectory() as outdir: + fname_conf = os.path.join(os.path.dirname(__file__), "pipeline_config_B.yml") + runner_presto_fakepsr(fname_conf, outdir) + + +def test_pipeline_presto_purenoise(): + with tempfile.TemporaryDirectory() as outdir: + fname_conf = os.path.join(os.path.dirname(__file__), "pipeline_config_A.yml") + runner_presto_purenoise(fname_conf, outdir) + + with tempfile.TemporaryDirectory() as outdir: + fname_conf = os.path.join(os.path.dirname(__file__), "pipeline_config_B.yml") + runner_presto_purenoise(fname_conf, outdir) + + +def test_config_validation(): + fname_conf = os.path.join(os.path.dirname(__file__), "pipeline_config_A.yml") + conf_correct = load_yaml(fname_conf) + + # Wrong parameter type + with 
tempfile.TemporaryDirectory() as outdir: + conf_wrong = deepcopy(conf_correct) + conf_wrong["dmselect"]["min"] = "LOL" + tmp = os.path.join(outdir, "wrong_config.yaml") + save_yaml(conf_wrong, tmp) + with raises(InvalidPipelineConfig): + runner_presto_fakepsr(tmp, outdir) + + # period_min too low + with tempfile.TemporaryDirectory() as outdir: + conf_wrong = deepcopy(conf_correct) + conf_wrong["ranges"][0]["ffa_search"]["period_min"] = 1.0e-9 + tmp = os.path.join(outdir, "wrong_config.yaml") + save_yaml(conf_wrong, tmp) + with raises(InvalidSearchRange): + runner_presto_fakepsr(tmp, outdir) + + # too many phase bins requested to fold candidates + with tempfile.TemporaryDirectory() as outdir: + conf_wrong = deepcopy(conf_correct) + conf_wrong["ranges"][0]["candidates"]["bins"] = int(42.0e9) + tmp = os.path.join(outdir, "wrong_config.yaml") + save_yaml(conf_wrong, tmp) + with raises(InvalidSearchRange): + runner_presto_fakepsr(tmp, outdir) + + # non-contiguous search ranges + with tempfile.TemporaryDirectory() as outdir: + conf_wrong = deepcopy(conf_correct) + conf_wrong["ranges"][0]["ffa_search"]["period_max"] = 0.50042 + tmp = os.path.join(outdir, "wrong_config.yaml") + save_yaml(conf_wrong, tmp) + with raises(InvalidSearchRange): + runner_presto_fakepsr(tmp, outdir) diff --git a/tests/test_rseek.py b/tests/test_rseek.py new file mode 100644 index 0000000..ffd3bf9 --- /dev/null +++ b/tests/test_rseek.py @@ -0,0 +1,145 @@ +import os +import tempfile + +import numpy as np +from riptide import TimeSeries +from riptide.apps.rseek import get_parser, run_program + + +SIGNAL_PERIOD = 1.0 +SIGNAL_FREQ = 1.0 / SIGNAL_PERIOD +DATA_TOBS = 128.0 +DATA_TSAMP = 256e-6 + +PARSER = get_parser() +EXPECTED_COLUMNS = {"period", "freq", "width", "ducy", "dm", "snr"} +DEFAULT_OPTIONS = dict( + Pmin=0.5, Pmax=2.0, bmin=480, bmax=520, smin=7.0, format="presto" +) + +INF_TEMPLATE = """ + Data file name without suffix = {basename:s} + Telescope used = Parkes + Instrument used = Multibeam + Object being observed = Pulsar + J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000 + J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000 + Data observed by = Kenji Oba + Epoch of observation (MJD) = 59000.000000 + Barycentered? (1=yes, 0=no) = 1 + Number of bins in the time series = {nsamp:d} + Width of each time series bin (sec) = {tsamp:.12e} + Any breaks in the data? (1=yes, 0=no) = 0 + Type of observation (EM band) = Radio + Beam diameter (arcsec) = 981 + Dispersion measure (cm-3 pc) = {dm:.12f} + Central freq of low channel (Mhz) = 1182.1953125 + Total bandwidth (Mhz) = 400 + Number of channels = 1024 + Channel bandwidth (Mhz) = 0.390625 + Data analyzed by = Space Sheriff Gavan + Any additional notes: + Input filterbank samples have 2 bits. +""" + + +def generate_data_presto( + outdir, + basename, + tobs=128.0, + tsamp=256e-6, + period=1.0, + dm=0.0, + amplitude=20.0, + ducy=0.05, +): + """ + Generate some time series data with a fake signal, and save it in PRESTO + inf/dat format in the specified output directory. + + Parameters + ---------- + outdir : str + Path to the output directory + basename : str + Base file name (not path) under which the .inf and .dat files + will be saved. 
+ **kwargs: self-explanatory + """ + ### IMPORTANT: seed the RNG to get reproducible results ### + np.random.seed(0) + + ts = TimeSeries.generate( + tobs, tsamp, period, amplitude=amplitude, ducy=ducy, stdnoise=1.0 + ) + inf_text = INF_TEMPLATE.format( + basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm + ) + + inf_path = os.path.join(outdir, f"{basename}.inf") + dat_path = os.path.join(outdir, f"{basename}.dat") + with open(inf_path, "w") as fobj: + fobj.write(inf_text) + ts.data.tofile(dat_path) + + +def dict2args(d): + """ + Convert dictionary of options to command line argument list + """ + args = [] + for k, v in d.items(): + args.append(f"--{k}") + args.append(str(v)) + return args + + +def test_rseek_fakepsr(): + with tempfile.TemporaryDirectory() as outdir: + generate_data_presto( + outdir, + "data", + tobs=DATA_TOBS, + tsamp=DATA_TSAMP, + period=SIGNAL_PERIOD, + dm=0.0, + amplitude=20.0, + ducy=0.02, + ) + fname = os.path.join(outdir, "data.inf") + cmdline_args = dict2args(DEFAULT_OPTIONS) + [fname] + args = PARSER.parse_args(cmdline_args) + df = run_program(args) + + assert df is not None + assert set(df.columns) == EXPECTED_COLUMNS + + # Results must be sorted in decreasing S/N order + assert all(df.snr == df.sort_values("snr", ascending=False).snr) + + # Check parameters of the top candidate + # NOTE: these checks depend on the RNG seed and the program options + topcand = df.iloc[0] + assert abs(topcand.freq - SIGNAL_FREQ) < 0.1 / DATA_TOBS + assert abs(topcand.snr - 18.5) < 0.15 + assert topcand.dm == 0 + assert topcand.width == 13 + + +def test_rseek_purenoise(): + with tempfile.TemporaryDirectory() as outdir: + generate_data_presto( + outdir, + "data", + tobs=DATA_TOBS, + tsamp=DATA_TSAMP, + period=SIGNAL_PERIOD, + dm=0.0, + amplitude=0.0, + ) + fname = os.path.join(outdir, "data.inf") + cmdline_args = dict2args(DEFAULT_OPTIONS) + [fname] + args = PARSER.parse_args(cmdline_args) + df = run_program(args) + + assert df is None diff --git a/riptide/tests/test_running_median.py b/tests/test_running_median.py similarity index 78% rename from riptide/tests/test_running_median.py rename to tests/test_running_median.py index ea19553..522336d 100644 --- a/riptide/tests/test_running_median.py +++ b/tests/test_running_median.py @@ -9,29 +9,29 @@ def running_median_naive(data, w): # The C++ running median implicitly pads both ends of the arrray with # the edge values, we reproduce this behaviour here - padded_data = np.pad(data, (h, h), mode='edge') + padded_data = np.pad(data, (h, h), mode="edge") def med(imid): - return np.median(padded_data[imid:imid+w]) + return np.median(padded_data[imid : imid + w]) return np.asarray([med(i) for i in range(data.size)]) def test_rmed_exceptions(): - data = np.arange(10, dtype='float32') + data = np.arange(10, dtype="float32") - with raises(ValueError): # width must be odd + with raises(ValueError): # width must be odd running_median(data, 2) - - with raises(ValueError): # width must be < size + + with raises(ValueError): # width must be < size running_median(data, data.size) - with raises(ValueError): # data must be 1D + with raises(ValueError): # data must be 1D running_median(np.zeros(shape=(4, 8)), 3) def test_rmed(): - x = np.random.normal(size=100).astype('float32') + x = np.random.normal(size=100).astype("float32") widths = [1, 3, 5, 7, 11, 25, 37] for w in widths: @@ -42,7 +42,7 @@ def test_rmed_non_contiguous_data(): # Test added after realizing that passing non-memory-contiguous array slices # to running_median() returned incorrect 
results # Thanks to Akshay Suresh for finding and reporting the problem - data = np.random.normal(size=300).reshape(100, 3).astype('float32') + data = np.random.normal(size=300).reshape(100, 3).astype("float32") widths = [1, 3, 5, 7, 11, 25, 37] # Calculate running median of columns @@ -70,5 +70,5 @@ def test_fast_rmed_noscrunch(): for w in widths: assert np.array_equal( running_median(x, w), - fast_running_median(x, w, min_points=min_points) - ) \ No newline at end of file + fast_running_median(x, w, min_points=min_points), + ) diff --git a/riptide/tests/test_snr.py b/tests/test_snr.py similarity index 81% rename from riptide/tests/test_snr.py rename to tests/test_snr.py index 3a5d18f..33a1fe6 100644 --- a/riptide/tests/test_snr.py +++ b/tests/test_snr.py @@ -1,13 +1,13 @@ -from pytest import raises import numpy as np +from pytest import raises from riptide import boxcar_snr def test_errors(): cols = 32 - data = np.zeros(cols, dtype='float32') - + data = np.zeros(cols, dtype="float32") + # No widths < 1 with raises(ValueError): boxcar_snr(data, [0, 1]) @@ -26,21 +26,21 @@ def test_output_dims(): widths = [1, 2, 3, 5] # 1D input - data = np.zeros(cols, dtype='float32') + data = np.zeros(cols, dtype="float32") snr = boxcar_snr(data, widths) assert snr.ndim == 1 assert snr.size == len(widths) # 2D input rows = 4 - data = np.zeros((rows, cols), dtype='float32') + data = np.zeros((rows, cols), dtype="float32") snr = boxcar_snr(data, widths) assert snr.ndim == 2 assert snr.shape == (rows, len(widths)) # 3D input layers = 3 - data = np.zeros((layers, rows, cols), dtype='float32') + data = np.zeros((layers, rows, cols), dtype="float32") snr = boxcar_snr(data, widths) assert snr.ndim == 3 assert snr.shape == (layers, rows, len(widths)) @@ -51,7 +51,7 @@ def test_phase_rotation_invariance(): cols = 32 widths = [1, 2, 5, 11, 18, 31] - data = np.random.normal(size=rows*cols).reshape(rows, cols).astype('float32') + data = np.random.normal(size=rows * cols).reshape(rows, cols).astype("float32") snr_ref = boxcar_snr(data, widths) for shift in range(1, cols + 1): @@ -62,9 +62,9 @@ def test_phase_rotation_invariance(): def test_values(): n = 64 widths = np.arange(1, n) - data = np.zeros(n, dtype='float32') + data = np.zeros(n, dtype="float32") - for w in range(1, n): # w = true width + for w in range(1, n): # w = true width data[:w] = 1.0 snr = boxcar_snr(data, widths) assert snr.argmax() == w - 1 diff --git a/riptide/tests/test_time_series.py b/tests/test_time_series.py similarity index 80% rename from riptide/tests/test_time_series.py rename to tests/test_time_series.py index cc46d84..c6567c5 100644 --- a/riptide/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -6,9 +6,10 @@ from riptide import TimeSeries, save_json, load_json -DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') +DATA_DIR = os.path.join(os.path.dirname(__file__), "data") FLOAT_ATOL = 1.0e-6 + # NOTE: a TimeSeries has only two basic attributes: data and tsamp # *** That's only what we test here *** # Anything else is handled by the Metadata class @@ -22,28 +23,28 @@ def check_data(ts, refdata): # The actual data expected to be in all test .dat files refdata = np.arange(16) - fname = os.path.join(DATA_DIR, 'fake_presto_radio.inf') + fname = os.path.join(DATA_DIR, "fake_presto_radio.inf") ts = TimeSeries.from_presto_inf(fname) check_data(ts, refdata) - fname = os.path.join(DATA_DIR, 'fake_presto_radio_breaks.inf') + fname = os.path.join(DATA_DIR, "fake_presto_radio_breaks.inf") ts = TimeSeries.from_presto_inf(fname) 
check_data(ts, refdata) # Calling TimeSeries.from_presto_inf() on X-ray and Gamma data should raise a warning # about the noise stats being non-Gaussian with warns(UserWarning): - fname = os.path.join(DATA_DIR, 'fake_presto_xray.inf') + fname = os.path.join(DATA_DIR, "fake_presto_xray.inf") ts = TimeSeries.from_presto_inf(fname) check_data(ts, refdata) def test_sigproc(): - refdata = np.arange(16) # what is supposed to be in the data + refdata = np.arange(16) # what is supposed to be in the data filenames = [ - 'fake_sigproc_float32.tim', - 'fake_sigproc_uint8.tim', - 'fake_sigproc_int8.tim', + "fake_sigproc_float32.tim", + "fake_sigproc_uint8.tim", + "fake_sigproc_int8.tim", ] for fname in filenames: @@ -57,7 +58,7 @@ def test_sigproc(): # Check that trying to read 8-bit SIGPROC data without a 'signed' # header key raises an error with raises(ValueError): - fname = os.path.join(DATA_DIR, 'fake_sigproc_uint8_nosignedkey.tim') + fname = os.path.join(DATA_DIR, "fake_sigproc_uint8_nosignedkey.tim") ts = TimeSeries.from_sigproc(fname) @@ -74,13 +75,13 @@ def check_ts(ts): ts = TimeSeries.from_numpy_array(refdata, tsamp) check_ts(ts) - with tempfile.NamedTemporaryFile(suffix='.npy') as f: + with tempfile.NamedTemporaryFile(suffix=".npy") as f: # re-creates the file, still gets deleted on exiting context mgr np.save(f.name, refdata) ts = TimeSeries.from_npy_file(f.name, tsamp) check_ts(ts) - with tempfile.NamedTemporaryFile(suffix='.bin') as f: + with tempfile.NamedTemporaryFile(suffix=".bin") as f: # re-creates the file, still gets deleted on exiting context mgr refdata.astype(np.float32).tofile(f.name) ts = TimeSeries.from_binary(f.name, tsamp) @@ -88,9 +89,9 @@ def check_ts(ts): def test_generate(): - length = 10.0 # s - tsamp = 0.01 # s - period = 1.0 # s + length = 10.0 # s + tsamp = 0.01 # s + period = 1.0 # s amplitude = 25.0 # Generate noiseless data to check its amplitude @@ -99,7 +100,7 @@ def test_generate(): assert ts.length == length assert ts.tsamp == tsamp assert ts.data.dtype == np.float32 - assert np.allclose(sum(ts.data ** 2) ** 0.5, amplitude, atol=FLOAT_ATOL) + assert np.allclose(sum(ts.data**2) ** 0.5, amplitude, atol=FLOAT_ATOL) def test_methods(): @@ -107,13 +108,15 @@ def test_methods(): NOTE: This tests that the code runs, but not the output data quality, i.e. 
if dereddening removes low-frequency noise well
     """
-    length = 10.0 # s
-    tsamp = 1.0e-3 # s
-    period = 1.0 # s
+    length = 10.0  # s
+    tsamp = 1.0e-3  # s
+    period = 1.0  # s
     amplitude = 25.0
     stdnoise = 1.0
 
-    tsorig = TimeSeries.generate(length, tsamp, period, amplitude=amplitude, stdnoise=stdnoise)
+    tsorig = TimeSeries.generate(
+        length, tsamp, period, amplitude=amplitude, stdnoise=stdnoise
+    )
     ts = tsorig.copy()
 
     ### Normalisation inplace / out of place ###
@@ -150,10 +153,10 @@ def test_methods():
         assert tscopy.nsamp == tsorig.nsamp // dsfactor
         assert tscopy.length == tsorig.length
 
-    with raises(ValueError): # stricly < 1
+    with raises(ValueError):  # strictly < 1
         ts = tsorig.downsample(0.55)
 
-    with raises(ValueError): # excessive
+    with raises(ValueError):  # excessive
         ts = tsorig.downsample(tsorig.nsamp * 10)
 
     ### Folding ###
@@ -180,7 +183,7 @@ def test_methods():
         assert np.allclose(prof, X10.sum(axis=0), atol=FLOAT_ATOL)
         assert np.allclose(prof, Xm.sum(axis=0), atol=FLOAT_ATOL)
 
-    # Too many requested subints 
+    # Too many requested subints
     with raises(ValueError):
         Xerr = tsorig.fold(1.0, bins, subints=1000000)
 
@@ -188,7 +191,7 @@ def test_methods():
     with raises(ValueError):
         Xerr = tsorig.fold(1.0, bins, subints=0)
 
-    # Too many requested bins 
+    # Too many requested bins
     with raises(ValueError):
         Xerr = tsorig.fold(1.0, 1000000, subints=None)
 
@@ -202,15 +205,21 @@ def test_methods():
 
 
 def test_serialization():
-    length = 10.0 # s
-    tsamp = 1.0e-3 # s
-    period = 1.0 # s
+    length = 10.0  # s
+    tsamp = 1.0e-3  # s
+    period = 1.0  # s
     amplitude = 25.0
     stdnoise = 1.0
 
-    ts = TimeSeries.generate(length, tsamp, period, amplitude=amplitude, stdnoise=stdnoise)
+    ts = TimeSeries.generate(
+        length,
+        tsamp,
+        period,
+        stdnoise=stdnoise,
+        amplitude=amplitude,
+    )
 
-    with tempfile.NamedTemporaryFile(suffix='.json') as f:
+    with tempfile.NamedTemporaryFile(suffix=".json") as f:
         save_json(f.name, ts)
         tscopy = load_json(f.name)
 
@@ -218,4 +227,3 @@ def test_serialization():
     assert ts.nsamp == tscopy.nsamp
     assert ts.length == tscopy.length
     assert np.allclose(ts.data, tscopy.data, atol=FLOAT_ATOL)
-
From bff9d77551aae4576f4b60a3bf9679e27abc1ff7 Mon Sep 17 00:00:00 2001
From: Ujjwal Panda
Date: Thu, 10 Apr 2025 00:50:50 +0530
Subject: [PATCH 2/3] Remove `_version.py` + add it to `.gitignore`.

---
 .gitignore              | 24 +++++++++++++++---------
 src/riptide/_version.py | 21 ---------------------
 2 files changed, 15 insertions(+), 30 deletions(-)
 delete mode 100644 src/riptide/_version.py

diff --git a/.gitignore b/.gitignore
index 87509bb..ae8e5ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,13 +1,19 @@
-*.o
-*.so
-*.pyc
-__pycache__
+dist
+build
+.egg
+.nox
 .vscode
-*.egg-info
 .coverage
-.eggs
+__pycache__
+.mypy_cache
+docs/_build
 .pytest_cache
-tmp
-build
-dist
 pip-wheel-metadata
+.ipynb_checkpoints
+src/riptide/_version.py
+
+*.o
+*.so
+*.egg
+*.py[cod]
+*.egg-info
diff --git a/src/riptide/_version.py b/src/riptide/_version.py
deleted file mode 100644
index 065144e..0000000
--- a/src/riptide/_version.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# file generated by setuptools-scm
-# don't change, don't track in version control
-
-__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
-
-TYPE_CHECKING = False
-if TYPE_CHECKING:
-    from typing import Tuple
-    from typing import Union
-
-    VERSION_TUPLE = Tuple[Union[int, str], ...] 
-else: - VERSION_TUPLE = object - -version: str -__version__: str -__version_tuple__: VERSION_TUPLE -version_tuple: VERSION_TUPLE - -__version__ = version = "0.2.6.dev5+g054e998.d20250409" -__version_tuple__ = version_tuple = (0, 2, 6, "dev5", "g054e998.d20250409") From eb95b6d7d71587874b9898fb37e3b69c298fc655 Mon Sep 17 00:00:00 2001 From: Ujjwal Panda Date: Fri, 18 Apr 2025 18:45:37 +0530 Subject: [PATCH 3/3] Fix duplicated code in tests. --- tests/__init__.py | 0 tests/presto_generation.py | 68 ++++++++++++++++++++++++++++++++++++++ tests/test_pipeline.py | 68 ++------------------------------------ tests/test_rseek.py | 66 +----------------------------------- 4 files changed, 71 insertions(+), 131 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/presto_generation.py b/tests/presto_generation.py index 8b13789..8ebae02 100644 --- a/tests/presto_generation.py +++ b/tests/presto_generation.py @@ -1 +1,69 @@ +import os +import numpy as np +from riptide import TimeSeries + +INF_TEMPLATE = """ + Data file name without suffix = {basename:s} + Telescope used = Parkes + Instrument used = Multibeam + Object being observed = Pulsar + J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000 + J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000 + Data observed by = Kenji Oba + Epoch of observation (MJD) = 59000.000000 + Barycentered? (1=yes, 0=no) = 1 + Number of bins in the time series = {nsamp:d} + Width of each time series bin (sec) = {tsamp:.12e} + Any breaks in the data? (1=yes, 0=no) = 0 + Type of observation (EM band) = Radio + Beam diameter (arcsec) = 981 + Dispersion measure (cm-3 pc) = {dm:.12f} + Central freq of low channel (Mhz) = 1182.1953125 + Total bandwidth (Mhz) = 400 + Number of channels = 1024 + Channel bandwidth (Mhz) = 0.390625 + Data analyzed by = Space Sheriff Gavan + Any additional notes: + Input filterbank samples have 2 bits. +""" + + +def generate_data_presto( + outdir, + basename, + tobs=128.0, + tsamp=256e-6, + period=1.0, + dm=0.0, + amplitude=20.0, + ducy=0.05, +): + """ + Generate some time series data with a fake signal, and save it in PRESTO + inf/dat format in the specified output directory. + + Parameters + ---------- + outdir : str + Path to the output directory + basename : str + Base file name (not path) under which the .inf and .dat files + will be saved. 
+    **kwargs: remaining arguments that set the signal and .inf header parameters
+    """
+    ### IMPORTANT: seed the RNG to get reproducible results ###
+    np.random.seed(0)
+
+    ts = TimeSeries.generate(
+        tobs, tsamp, period, amplitude=amplitude, ducy=ducy, stdnoise=1.0
+    )
+    inf_text = INF_TEMPLATE.format(
+        basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm
+    )
+
+    inf_path = os.path.join(outdir, f"{basename}.inf")
+    dat_path = os.path.join(outdir, f"{basename}.dat")
+    with open(inf_path, "w") as fobj:
+        fobj.write(inf_text)
+    ts.data.tofile(dat_path)
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 9b5e6d5..3b84d51 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -11,6 +11,8 @@
 from riptide.pipeline.pipeline import get_parser, run_program
 from riptide.pipeline.config_validation import InvalidPipelineConfig, InvalidSearchRange
 
+from .presto_generation import generate_data_presto
+
 # NOTE 1:
 # pipeline uses multiprocessing, to get proper coverage stats we need:
 # * A .coveragerc file with the following options:
@@ -36,72 +38,6 @@
 DATA_TSAMP = 256e-6
 
 
-INF_TEMPLATE = """
- Data file name without suffix = {basename:s}
- Telescope used = Parkes
- Instrument used = Multibeam
- Object being observed = Pulsar
- J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000
- J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000
- Data observed by = Kenji Oba
- Epoch of observation (MJD) = 59000.000000
- Barycentered? (1=yes, 0=no) = 1
- Number of bins in the time series = {nsamp:d}
- Width of each time series bin (sec) = {tsamp:.12e}
- Any breaks in the data? (1=yes, 0=no) = 0
- Type of observation (EM band) = Radio
- Beam diameter (arcsec) = 981
- Dispersion measure (cm-3 pc) = {dm:.12f}
- Central freq of low channel (Mhz) = 1182.1953125
- Total bandwidth (Mhz) = 400
- Number of channels = 1024
- Channel bandwidth (Mhz) = 0.390625
- Data analyzed by = Space Sheriff Gavan
- Any additional notes:
- Input filterbank samples have 2 bits.
-"""
-
-
-def generate_data_presto(
-    outdir,
-    basename,
-    tobs=128.0,
-    tsamp=256e-6,
-    period=1.0,
-    dm=0.0,
-    amplitude=20.0,
-    ducy=0.05,
-):
-    """
-    Generate some time series data with a fake signal, and save it in PRESTO
-    inf/dat format in the specified output directory.
-
-    Parameters
-    ----------
-    outdir : str
-        Path to the output directory
-    basename : str
-        Base file name (not path) under which the .inf and .dat files
-        will be saved.
- **kwargs: self-explanatory - """ - ### IMPORTANT: seed the RNG to get reproducible results ### - np.random.seed(0) - - ts = TimeSeries.generate( - tobs, tsamp, period, amplitude=amplitude, ducy=ducy, stdnoise=1.0 - ) - inf_text = INF_TEMPLATE.format( - basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm - ) - - inf_path = os.path.join(outdir, f"{basename}.inf") - dat_path = os.path.join(outdir, f"{basename}.dat") - with open(inf_path, "w") as fobj: - fobj.write(inf_text) - ts.data.tofile(dat_path) - - def runner_presto_fakepsr(fname_conf, outdir): # Write test data # NOTE: generate a signal bright enough to get harmonics and thus make sure diff --git a/tests/test_rseek.py b/tests/test_rseek.py index ffd3bf9..d285d9c 100644 --- a/tests/test_rseek.py +++ b/tests/test_rseek.py @@ -5,6 +5,7 @@ from riptide import TimeSeries from riptide.apps.rseek import get_parser, run_program +from .presto_generation import generate_data_presto SIGNAL_PERIOD = 1.0 SIGNAL_FREQ = 1.0 / SIGNAL_PERIOD @@ -17,71 +18,6 @@ Pmin=0.5, Pmax=2.0, bmin=480, bmax=520, smin=7.0, format="presto" ) -INF_TEMPLATE = """ - Data file name without suffix = {basename:s} - Telescope used = Parkes - Instrument used = Multibeam - Object being observed = Pulsar - J2000 Right Ascension (hh:mm:ss.ssss) = 00:00:01.0000 - J2000 Declination (dd:mm:ss.ssss) = -00:00:01.0000 - Data observed by = Kenji Oba - Epoch of observation (MJD) = 59000.000000 - Barycentered? (1=yes, 0=no) = 1 - Number of bins in the time series = {nsamp:d} - Width of each time series bin (sec) = {tsamp:.12e} - Any breaks in the data? (1=yes, 0=no) = 0 - Type of observation (EM band) = Radio - Beam diameter (arcsec) = 981 - Dispersion measure (cm-3 pc) = {dm:.12f} - Central freq of low channel (Mhz) = 1182.1953125 - Total bandwidth (Mhz) = 400 - Number of channels = 1024 - Channel bandwidth (Mhz) = 0.390625 - Data analyzed by = Space Sheriff Gavan - Any additional notes: - Input filterbank samples have 2 bits. -""" - - -def generate_data_presto( - outdir, - basename, - tobs=128.0, - tsamp=256e-6, - period=1.0, - dm=0.0, - amplitude=20.0, - ducy=0.05, -): - """ - Generate some time series data with a fake signal, and save it in PRESTO - inf/dat format in the specified output directory. - - Parameters - ---------- - outdir : str - Path to the output directory - basename : str - Base file name (not path) under which the .inf and .dat files - will be saved. - **kwargs: self-explanatory - """ - ### IMPORTANT: seed the RNG to get reproducible results ### - np.random.seed(0) - - ts = TimeSeries.generate( - tobs, tsamp, period, amplitude=amplitude, ducy=ducy, stdnoise=1.0 - ) - inf_text = INF_TEMPLATE.format( - basename=basename, nsamp=ts.nsamp, tsamp=tsamp, dm=dm - ) - - inf_path = os.path.join(outdir, f"{basename}.inf") - dat_path = os.path.join(outdir, f"{basename}.dat") - with open(inf_path, "w") as fobj: - fobj.write(inf_text) - ts.data.tofile(dat_path) - def dict2args(d): """
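With the duplicated generator now shared from tests/presto_generation.py, every
test module obtains its fixture data the same way. A minimal sketch of the
intended usage, placed alongside the other modules under tests/ (the test name
and pytest's tmp_path fixture are illustrative assumptions, not part of the
patch):

    import os

    from riptide import TimeSeries

    from .presto_generation import generate_data_presto


    def test_presto_data_roundtrip(tmp_path):
        # Write fake.inf / fake.dat into a pytest-managed temporary directory;
        # the helper seeds NumPy's RNG, so the data are identical on every run
        generate_data_presto(str(tmp_path), "fake", tobs=16.0, period=1.0)

        # Read the data back through the PRESTO reader and sanity-check it
        ts = TimeSeries.from_presto_inf(os.path.join(str(tmp_path), "fake.inf"))
        assert abs(ts.tsamp - 256e-6) < 1e-12  # helper's default sampling time
        assert ts.length > 0

Because generate_data_presto() seeds the RNG itself, any two tests requesting
the same parameters get exactly the same noise realisation, which keeps test
outcomes reproducible across runs and modules.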