diff --git a/MANIFEST.in b/MANIFEST.in index ed72932..23068a1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -16,6 +16,8 @@ include libstempo/multinest.py include libstempo/plot.py include libstempo/spharmORFbasis.py include libstempo/toasim.py +include libstempo/sandbox.py +include libstempo/tim_file_analyzer.py include libstempo/ecc_vs_nharm.txt include demo/libstempo-demo.ipynb include demo/libstempo-toasim-demo.ipynb diff --git a/README.md b/README.md index 5e44f70..b287af0 100644 --- a/README.md +++ b/README.md @@ -64,3 +64,102 @@ pip install libstempo ## Usage See [Demo Notebook 1](https://github.com/vallis/libstempo/blob/master/demo/libstempo-demo.ipynb) for basic usage and [Demo Notebook 2](https://github.com/vallis/libstempo/blob/master/demo/libstempo-toasim-demo.ipynb) for simulation usage. + +## Sandbox Mode (Crash-Protected) + +libstempo includes a sandbox mode that provides crash isolation and automatic retry capabilities. This is particularly useful when working with problematic pulsars or long-running analyses where tempo2 crashes are common. + +### Basic Usage + +The sandbox provides a drop-in replacement for the standard `tempopulsar` class: + +```python +from libstempo.sandbox import tempopulsar + +# Basic usage - same API as regular tempopulsar +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False) +residuals = psr.residuals() +design_matrix = psr.designmatrix() +``` + +### Advanced Configuration + +```python +from libstempo.sandbox import tempopulsar, Policy, configure_logging + +# Configure logging for debugging +configure_logging(level="DEBUG", log_file="tempo2.log") + +# Configure retry and timeout policies +policy = Policy( + ctor_retry=5, # Retry constructor 5 times on failure + call_timeout_s=300.0, # 5-minute timeout per RPC call + max_calls_per_worker=1000, # Recycle worker after 1000 calls + max_age_s=3600, # Recycle worker after 1 hour + rss_soft_limit_mb=2048 # Recycle worker if memory exceeds 2GB +) + +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy) +``` + +### Environment Support + +The sandbox supports different Python environments: + +```python +# Use virtual/conda environment +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="tempo2_intel") + +# Use system Python with Rosetta (macOS) +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="arch") + +# Use explicit Python path +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="python:/path/to/python") +``` + +### Key Benefits + +- **Crash Isolation**: Segfaults in tempo2 only kill the worker process, not your main kernel +- **Automatic Retry**: Built-in retry logic for transient failures +- **Worker Recycling**: Prevents memory leaks and resource accumulation +- **Environment Flexibility**: Support for conda, venv, and Rosetta environments +- **Enhanced Logging**: Comprehensive logging for debugging and monitoring +- **Proactive TOA Handling**: Automatically handles large TOA files to prevent "Too many TOAs" errors + +### Performance + +The sandbox adds ~9x initialization overhead but only ~1.2x overhead for computational operations like `residuals()` and `designmatrix()`. For heavy computations, the overhead becomes negligible relative to the actual work. Use sandbox when stability is critical, direct libstempo when performance is paramount. 
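+
+To gauge the overhead on your own data, you can time both implementations side by side. This is a rough sketch; the par/tim file names are placeholders:
+
+```python
+import time
+
+from libstempo import tempopulsar as direct_tempopulsar
+from libstempo.sandbox import tempopulsar as sandbox_tempopulsar
+
+def time_pulsar(cls, n=100):
+    # Time construction once, then average n residuals() calls
+    t0 = time.perf_counter()
+    psr = cls(parfile="J1713.par", timfile="J1713.tim", dofit=False)
+    t_init = time.perf_counter() - t0
+    t0 = time.perf_counter()
+    for _ in range(n):
+        psr.residuals()
+    return t_init, (time.perf_counter() - t0) / n
+
+print("direct :", time_pulsar(direct_tempopulsar))
+print("sandbox:", time_pulsar(sandbox_tempopulsar))
+```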
+ +### Bulk Loading + +For processing many pulsars: + +```python +from libstempo.sandbox import load_many, Policy + +pairs = [("J1713.par", "J1713.tim"), ("J1909.par", "J1909.tim"), ...] +policy = Policy(ctor_retry=3, call_timeout_s=120.0) + +ok_by_name, retried_by_name, failed_list = load_many(pairs, policy=policy, parallel=8) + +print(f"Successfully loaded: {len(ok_by_name)}") +print(f"Required retries: {len(retried_by_name)}") +print(f"Failed: {len(failed_list)}") +``` + +### Error Handling + +The sandbox defines specific exception types: + +```python +from libstempo.sandbox import Tempo2Error, Tempo2Crashed, Tempo2Timeout + +try: + psr = tempopulsar(parfile="problematic.par", timfile="problematic.tim") +except Tempo2Crashed: + print("Worker process crashed - likely a segfault") +except Tempo2Timeout: + print("Worker timed out") +except Tempo2Error as e: + print(f"Sandbox error: {e}") +``` diff --git a/libstempo/__init__.py b/libstempo/__init__.py index e672189..519801f 100644 --- a/libstempo/__init__.py +++ b/libstempo/__init__.py @@ -1,6 +1,9 @@ import os from ._find_tempo2 import find_tempo2_runtime +# Import sandbox functionality +from .sandbox import tempopulsar as sandbox_tempopulsar, Policy, configure_logging # noqa: F401 +from .tim_file_analyzer import TimFileAnalyzer # noqa: F401 # check to see if TEMPO2 environment variable is set TEMPO2_RUNTIME = os.getenv("TEMPO2") diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py new file mode 100644 index 0000000..18e621b --- /dev/null +++ b/libstempo/sandbox.py @@ -0,0 +1,1863 @@ +# flake8: noqa: E501 +""" +Author: Rutger van Haasteren -- rutger@vhaasteren.com +Date: 2025-10-10 + +Process sandbox for libstempo/tempo2 that keeps each pulsar in its own clean +subprocess. A segfault in tempo2/libstempo only kills the worker, not your kernel. + +Usage (drop-in): + from sandbox import tempopulsar + psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False) + r = psr.residuals() + +Advanced with logging: + from sandbox import tempopulsar, configure_logging, Policy + configure_logging(level="DEBUG", log_file="tempo2.log") + policy = Policy(ctor_retry=5, call_timeout_s=300.0) + psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy) + +With specific environment: + psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="myenv") + # or for conda: env_name="mycondaenv" + # or explicit path: env_name="python:/path/to/python" + +With persistent workers (no recycling/timeouts): + policy = Policy( + call_timeout_s=None, # No RPC timeouts + max_calls_per_worker=None, # Never recycle by call count + max_age_s=None, # Never recycle by age + rss_soft_limit_mb=None # Never recycle by memory + ) + psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy) + +Advanced: + from sandbox import load_many, Policy + ok, retried, failed = load_many([("J1713.par","J1713.tim"), ...], policy=Policy()) + +Environment selection (Apple Silicon + Rosetta etc.): + psr = tempopulsar(..., env_name="tempo2_intel") # conda env + psr = tempopulsar(..., env_name="myvenv") # venv (~/.venvs/myvenv, etc.) + psr = tempopulsar(..., env_name="arch") # system python via Rosetta (arch -x86_64) + psr = tempopulsar(..., env_name="python:/abs/python") # explicit Python path + +You can force Rosetta prefix via env var: + TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64" + +Logging: + The sandbox includes comprehensive loguru logging for debugging and monitoring. + Use configure_logging() to set up logging levels and outputs. 
Logs include: + - Worker process lifecycle (creation, recycling, termination) + - RPC call details and timing + - Constructor retry attempts and failures + - Memory usage and recycling decisions + - Error details and recovery attempts + +Robustness: + The sandbox redirects stdout to stderr in worker processes before importing libstempo, + preventing interference with the JSON-RPC protocol. This ensures reliable + communication even when libstempo or other libraries print diagnostic messages. + The redirection works at the OS file descriptor level to catch output from C libraries. +""" + +from __future__ import annotations + +import base64 +import contextlib +import dataclasses +import json +import os +import pickle +import platform +import select +import shutil +import signal +import subprocess +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple +from collections.abc import Mapping + +# Import TimFileAnalyzer for proactive TOA counting +from .tim_file_analyzer import TimFileAnalyzer + +# Standard logging +import logging + +logger = logging.getLogger(__name__) + +# ---------------------------- Public Exceptions ---------------------------- # + + +class Tempo2Error(Exception): + """Base class for sandbox errors.""" + + +class Tempo2Crashed(Tempo2Error): + """The worker process crashed or died unexpectedly (likely a segfault).""" + + +class Tempo2Timeout(Tempo2Error): + """The worker did not reply in time; it was terminated.""" + + +class Tempo2ProtocolError(Tempo2Error): + """Malformed RPC request/response or other IPC failure.""" + + +class Tempo2ConstructorFailed(Tempo2Error): + """Constructor failed even after retries.""" + + +# ------------------------------- Policy knobs ----------------------------- # +@dataclass(frozen=True) +class Policy: + """Configuration policy for sandbox worker behavior and lifecycle management. + + Controls retry behavior, timeouts, and worker recycling policies. 
+ """ + + # Constructor protection + ctor_retry: int = 5 # number of extra tries after the first + ctor_backoff: float = 0.75 # seconds between ctor retries + preload_residuals: bool = False # call residuals() once after ctor + preload_designmatrix: bool = False # call designmatrix() once after ctor + preload_toas: bool = False # call toas() once after ctor + preload_fit: bool = False # call fit() once after ctor + + # RPC protection + call_timeout_s: Optional[float] = None # per-call timeout (seconds), None = no timeout + kill_grace_s: float = 2.0 # after timeout, wait before SIGKILL + + # Recycling / hygiene + max_calls_per_worker: Optional[int] = None # recycle after this many good calls, None = never recycle by calls + max_age_s: Optional[float] = None # recycle after this many seconds, None = never recycle by age + rss_soft_limit_mb: Optional[int] = None # if provided, recycle when beaten + + # Proactive TOA handling for large files + auto_nobs_retry: bool = True # automatically add nobs parameter for large TOA files + nobs_threshold: int = 10000 # add nobs parameter if TOA count exceeds this threshold + nobs_safety_margin: float = 1.1 # multiplier for nobs parameter (e.g., 1.1 = 10% more than actual count) + + # Logging / stderr capture + stderr_ring_max_lines: int = 20000 + stderr_log_file: Optional[str] = None + include_stderr_tail_in_errors: int = 200 # 0 disables tail inclusion + + +# -------------------------- Wire serialization helpers --------------------- # + +# We send JSON-RPC 2.0 frames. To avoid JSON-encoding numpy arrays and +# cross-arch issues, params/result travel as base64-encoded cloudpickle blobs. + +try: + import cloudpickle as _cp # best-effort; falls back to pickle if missing +except Exception: + _cp = pickle + + +def _b64_dumps_py(obj: Any) -> str: + """Serialize Python object to base64-encoded string using cloudpickle.""" + return base64.b64encode(_cp.dumps(obj)).decode("ascii") + + +def _b64_loads_py(s: str) -> Any: + """Deserialize base64-encoded string to Python object using cloudpickle.""" + if hasattr(s, 'encode'): + s_str = s + else: + s_str = str(s) + return _cp.loads(base64.b64decode(s_str.encode("ascii"))) + + +def _format_exc_tuple() -> Tuple[str, str, str]: + """Format current exception info as tuple of (type_name, message, traceback).""" + et, ev, tb = sys.exc_info() + name = et.__name__ if et else "Exception" + return (name, str(ev), "".join(traceback.format_exception(et, ev, tb))) + + +def _current_rss_mb_portable() -> Optional[int]: + """Get current process RSS memory usage in MB, portable across platforms.""" + try: + if sys.platform.startswith("linux"): + with open("/proc/self/statm") as f: + pages = int(f.read().split()[1]) + rss = pages * (os.sysconf("SC_PAGE_SIZE") // 1024 // 1024) + return rss + except Exception: + pass + try: + import psutil # type: ignore + + return int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)) + except Exception: + return None + + +# ----------------------------- Worker (stdio) ------------------------------ # + + +def _worker_stdio_main() -> None: + """ + Runs inside the worker interpreter (possibly Rosetta x86_64). + Protocol: + 1) Immediately print a single 'hello' JSON line with environment info. + 2) Then serve JSON-RPC 2.0 requests line-by-line on stdin/stdout. + Methods: ctor, get, set, call, del, rss, bye + Each request's 'params_b64' is a pickled dict of parameters. + Each response uses 'result_b64' for Python results, or 'error'. 
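+
+    Illustrative request/response pair for a 'call' (the base64 payloads are
+    shown as placeholders, not real cloudpickle blobs):
+        -> {"jsonrpc": "2.0", "id": 7, "method": "call",
+            "params_b64": "<b64({'name': 'residuals', 'args': (), 'kwargs': {}})>"}
+        <- {"jsonrpc": "2.0", "id": 7, "result_b64": "<b64(ndarray of residuals)>"}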
+ + To prevent interference with the JSON-RPC protocol, stdout is redirected to stderr + before importing libstempo. This ensures clean communication even when libstempo + prints diagnostic messages. The redirection works at the OS file descriptor level. + """ + import os as _os_for_fds + + # Check if redirection was already set up by the subprocess command + _proto_out = None # type: ignore + _env_fd = os.environ.get("TEMPO2_SANDBOX_PROTO_FD") + if _env_fd is not None: + try: + _proto_fd = int(_env_fd) + _proto_out = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) + except Exception: + _proto_out = None + finally: + # Remove the hint to avoid leaking to children + with contextlib.suppress(Exception): + os.environ.pop("TEMPO2_SANDBOX_PROTO_FD", None) + if _proto_out is None: + # Fallback: perform redirection here if not done in command + _proto_fd = _os_for_fds.dup(1) # save original stdout FD for protocol + _os_for_fds.dup2(2, 1) # route any C/printf stdout to stderr + sys.stdout = sys.stderr # route Python-level prints to stderr + _proto_out = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) + + # Step 1: hello handshake + hello = { + "hello": { + "python": sys.version.split()[0], + "executable": sys.executable, + "machine": platform.machine(), + "platform": platform.platform(), + "has_libstempo": False, + "tempo2_version": None, + "proto_version": "1.2", + "capabilities": {"get_kind": True, "dir": True, "setitem": True, "get_slice": True, "path_access": True}, + } + } + try: + try: + from libstempo import tempopulsar as _lib_tempopulsar # noqa + import numpy # noqa + + hello["hello"]["has_libstempo"] = True + # best-effort tempo2 version probe + try: + from libstempo import tempo2 # type: ignore + + hello["hello"]["tempo2_version"] = getattr(tempo2, "TEMPO2_VERSION", None) + except Exception: + pass + except Exception: + pass + finally: + _proto_out.write(json.dumps(hello) + "\n") + _proto_out.flush() + + # If libstempo failed to import at hello, try once more here to return clean errors + try: + from libstempo import tempopulsar as _lib_tempopulsar # noqa + import numpy # noqa + except Exception: + # Keep serving, but report on first request + _lib_tempopulsar: Optional[Any] = None + + obj = None + + def _write_response(resp: Dict[str, Any]) -> None: + """Write JSON response to stdout and flush.""" + _proto_out.write(json.dumps(resp) + "\n") + _proto_out.flush() + + # JSON-RPC loop + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + req = json.loads(line) + except Exception: + _write_response( + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "parse error"}, + } + ) + continue + + rid = req.get("id", None) + method = req.get("method", "") + params_b64 = req.get("params_b64", None) + + # Decode params dict if present + params = {} + if params_b64 is not None: + try: + params = _b64_loads_py(params_b64) + if not isinstance(params, dict): + raise TypeError("params_b64 must decode to dict") + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": { + "code": -32602, + "message": f"invalid params: {ev}", + "data": tb, + }, + } + ) + continue + + # Handle methods + try: + if method == "bye": + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py("bye")}) + return + + if method == "rss": + rss = _current_rss_mb_portable() + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(rss)}) + continue + + if method == "ctor": + if 
_lib_tempopulsar is None: + raise ImportError("libstempo not available in worker") + + obj = _lib_tempopulsar(**params["kwargs"]) + if params.get("preload_residuals", True): + _ = obj.residuals(updatebats=True, formresiduals=True) + + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "result_b64": _b64_dumps_py("constructed"), + } + ) + continue + + if obj is None: + raise RuntimeError("object not constructed") + + if method == "get": + name = params["name"] + # Support dotted path and mapping access for parameters (e.g., 'RAJ.val') + parts = str(name).split(".") if isinstance(name, str) else [name] + cur = obj + missing = False + for idx, part in enumerate(parts): + # First hop supports attribute or mapping access + if idx == 0: + try: + if hasattr(cur, part): + cur = getattr(cur, part) + elif isinstance(cur, Mapping) or hasattr(cur, "__getitem__"): + cur = cur[part] + else: + raise AttributeError + except Exception: + missing = True + break + else: + try: + cur = getattr(cur, part) + except Exception: + missing = True + break + + if missing: + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "missing", "value": None})} + ) + continue + + # cur is the resolved object/value + if callable(cur): + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "callable", "value": None})} + ) + continue + + try: + import numpy as _np2 + + if isinstance(cur, _np2.ndarray): + cur = cur.copy(order="C") + elif isinstance(cur, _np2.generic): + cur = cur.item() + except Exception: + pass + + # Safely serialize value; some libstempo/Boost.Python objects are not picklable + try: + _ = _b64_dumps_py({"kind": "value", "value": cur}) + result_payload = {"kind": "value", "value": cur} + except Exception: + # Best-effort conversion for libstempo parameter-like objects + safe_value = None + try: + # Detect param-like structures (e.g., RAJ/DEC) and extract primitives + has_val = hasattr(cur, "val") or hasattr(cur, "_val") + has_err = hasattr(cur, "err") or hasattr(cur, "_err") + has_fit = hasattr(cur, "fit") or hasattr(cur, "fitFlag") or hasattr(cur, "_fitFlag") + if has_val or has_err or has_fit: + val = None + err = None + fit = None + with contextlib.suppress(Exception): + v = getattr(cur, "val", getattr(cur, "_val", None)) + # Convert numpy scalars to Python + try: + import numpy as _np2 + + if isinstance(v, _np2.generic): + v = v.item() + except Exception: + pass + val = v + with contextlib.suppress(Exception): + e = getattr(cur, "err", getattr(cur, "_err", None)) + try: + import numpy as _np2 + + if isinstance(e, _np2.generic): + e = e.item() + except Exception: + pass + err = e + with contextlib.suppress(Exception): + f = getattr(cur, "fit", None) + if f is None: + f = getattr(cur, "fitFlag", getattr(cur, "_fitFlag", None)) + # Normalize to bool when possible + if isinstance(f, (int, bool)): + fit = bool(f) + else: + fit = f + name_guess = None + with contextlib.suppress(Exception): + name_guess = getattr(cur, "name", None) or getattr(cur, "label", None) + safe_value = { + "__libstempo_param__": True, + "name": name_guess, + "val": val, + "err": err, + "fit": fit, + } + else: + # Fallback to repr string if completely opaque + safe_value = {"__repr__": repr(cur)} + except Exception: + safe_value = {"__repr__": repr(cur)} + + result_payload = {"kind": "value", "value": safe_value} + + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "result_b64": _b64_dumps_py(result_payload), + } + ) + continue + + if method == "set": + name, 
value = params["name"], params["value"] + # Support dotted path and mapping access for parameters (e.g., 'RAJ.val') + parts = str(name).split(".") if isinstance(name, str) else [name] + cur = obj + missing = False + # Traverse to parent of target + for idx, part in enumerate(parts[:-1]): + try: + if idx == 0: + if hasattr(cur, part): + cur = getattr(cur, part) + elif isinstance(cur, Mapping) or hasattr(cur, "__getitem__"): + cur = cur[part] + else: + raise AttributeError + else: + cur = getattr(cur, part) + except Exception: + missing = True + break + if missing: + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": { + "code": -32000, + "message": f"AttributeError: cannot resolve path for set: {name}", + "data": "", + }, + } + ) + continue + target = parts[-1] + try: + setattr(cur, target, value) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) + continue + + if method == "setitem": + # Set slice(s) on numpy array attributes like stoas, toaerrs + name = params["name"] + index = params["index"] + value = params["value"] + try: + arr = getattr(obj, name) + try: + import numpy as _np2 + + if isinstance(value, _np2.ndarray) and not _np2.can_cast( + value.dtype, arr.dtype, casting="safe" + ): + value = value.astype(arr.dtype, copy=False) + except Exception: + pass + arr[index] = value + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue + + if method == "get_slice": + name = params["name"] + index = params["index"] + try: + arr = getattr(obj, name) + import numpy as _np2 + + out = arr[index] + if isinstance(out, _np2.ndarray): + out = out.copy(order="C") + elif isinstance(out, _np2.generic): + out = out.item() + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)}) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue + + if method == "call": + name = params["name"] + args = tuple(params.get("args", ())) + kwargs = dict(params.get("kwargs", {})) + meth = getattr(obj, name) + out = meth(*args, **kwargs) + try: + import numpy as _np2 + + if isinstance(out, _np2.ndarray): + # Always copy numpy arrays to avoid C++ object references + out = out.copy(order="C") + elif isinstance(out, _np2.generic): + out = out.item() + except Exception: + pass + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)}) + continue + + if method == "del": + try: + del obj + except Exception: + pass + obj = None + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) + continue + + if method == "dir": + names = [] + for n in dir(obj): + if not n.startswith("_"): + names.append(n) + names.sort() + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(names)}) + continue + + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": {"code": -32601, "message": f"method not found: {method}"}, + } + ) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": {"code": -32000, "message": 
f"{et}: {ev}", "data": tb}, + } + ) + + +# ------------------------------ Subprocess client -------------------------- # + + +class _WorkerProc: + """ + JSON-RPC over stdio subprocess. + Launches the worker in the requested environment (conda/venv/arch/system). + """ + + def __init__(self, policy: Policy, cmd: List[str], require_x86_64: bool = False): + self.policy = policy + self.cmd = cmd + self.proc: Optional[subprocess.Popen] = None + self._id = 0 + logger.info(f"Creating worker process with command: {' '.join(cmd)}") + logger.info(f"Require x86_64 architecture: {require_x86_64}") + self._start(require_x86_64=require_x86_64) + + # ---------- process management ---------- + + def _start(self, require_x86_64: bool = False): + logger.debug("Starting worker subprocess...") + self._hard_kill() # just in case + + # Ensure unbuffered text I/O + env = os.environ.copy() + env.setdefault("PYTHONUNBUFFERED", "1") + + logger.debug(f"Launching subprocess with environment: PYTHONUNBUFFERED={env.get('PYTHONUNBUFFERED')}") + logger.debug(f"Subprocess working directory: {os.getcwd()}") + creationflags = 0 + if os.name == "nt": + try: + creationflags = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore[attr-defined] + except Exception: + creationflags = 0 + self.proc = subprocess.Popen( + self.cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, # line buffered + cwd=os.getcwd(), # Explicitly set working directory + env=env, + close_fds=True, + start_new_session=True, + creationflags=creationflags, + ) + + logger.debug(f"Worker process started with PID: {self.proc.pid}") + + # Start background stderr drain to capture logs AND output to real-time stderr + import threading + import collections + + self._log_buf = collections.deque(maxlen=self.policy.stderr_ring_max_lines) + log_file = None + if self.policy.stderr_log_file: + try: + log_file = open(self.policy.stderr_log_file, "a", buffering=1, encoding="utf-8") + except Exception: + log_file = None + + def _drain_stderr(pipe, sink_deque, sink_file): + try: + for line in iter(pipe.readline, ""): + line = line.rstrip("\n") + sink_deque.append(line) + + # Write to real stderr for real-time output (native-like behavior) + print(line, file=sys.stderr, flush=True) + + if sink_file: + try: + sink_file.write(line + "\n") + except Exception: + pass + logger.debug("[tempo2-stderr] %s", line) + finally: + with contextlib.suppress(Exception): + pipe.close() + if sink_file: + with contextlib.suppress(Exception): + sink_file.flush() + sink_file.close() + + if self.proc.stderr is not None: + self._stderr_thread = threading.Thread( + target=_drain_stderr, args=(self.proc.stderr, self._log_buf, log_file), daemon=True + ) + self._stderr_thread.start() + + # Hello handshake (one line of JSON) + logger.debug("Waiting for worker hello handshake...") + hello = self._readline_with_timeout(self.policy.call_timeout_s) + if hello is None: + if self.policy.call_timeout_s is None: + logger.error("Worker did not send hello - worker disconnected") + self._hard_kill() + raise Tempo2Crashed("worker did not send hello - worker disconnected") + else: + logger.error("Worker did not send hello in time") + self._hard_kill() + raise Tempo2Timeout("worker did not send hello in time") + + try: + hello_obj = json.loads(hello) + except Exception as e: + logger.error(f"Failed to parse worker hello: {e}") + self._hard_kill() + raise Tempo2ProtocolError(f"malformed hello: {hello!r}") + + info = hello_obj.get("hello", {}) + logger.info(f"Worker 
hello received: {info}") + self._proto_version = info.get("proto_version", "1.0") + + if require_x86_64: + if str(info.get("machine", "")).lower() != "x86_64": + logger.error(f"Architecture mismatch: worker is {info.get('machine')}, but x86_64 required") + self._hard_kill() + raise Tempo2Error(f"worker arch is {info.get('machine')}, but x86_64 is required for quad precision") + + if not info.get("has_libstempo", False): + logger.error("libstempo not available in worker environment") + # Keep the worker up; subsequent ctor will return a clean error, + # but we can already warn here to fail fast. + self._hard_kill() + raise Tempo2Error( + "libstempo is not importable inside the selected environment. " + f"Worker executable: {info.get('executable')}" + ) + + self.birth = time.time() + self.calls_ok = 0 + logger.info(f"Worker ready and initialized (PID: {self.proc.pid})") + + def _readline_with_timeout(self, timeout: Optional[float]) -> Optional[str]: + if self.proc is None or self.proc.stdout is None: + return None + + if timeout is None: + # No timeout - wait indefinitely + while True: + rlist, _, _ = select.select([self.proc.stdout], [], []) + if rlist: + line = self.proc.stdout.readline() + if not line: # EOF + return None + return line.rstrip("\n") + else: + # With timeout + end = time.time() + timeout + while time.time() < end: + rlist, _, _ = select.select([self.proc.stdout], [], [], max(0.01, end - time.time())) + if rlist: + line = self.proc.stdout.readline() + if not line: # EOF + return None + return line.rstrip("\n") + return None + + def _hard_kill(self): + if self.proc and self.proc.poll() is None: + logger.debug(f"Hard killing worker process (PID: {self.proc.pid})") + try: + if os.name == "nt": + with contextlib.suppress(Exception): + self.proc.terminate() + else: + os.killpg(self.proc.pid, signal.SIGTERM) + except Exception as e: + logger.warning(f"Failed to terminate process group: {e}; falling back to terminate()") + with contextlib.suppress(Exception): + self.proc.terminate() + t0 = time.time() + while self.proc.poll() is None and (time.time() - t0) < self.policy.kill_grace_s: + time.sleep(0.01) + if self.proc.poll() is None: + logger.debug(f"Sending SIGKILL to worker process (PID: {self.proc.pid})") + with contextlib.suppress(Exception): + try: + if os.name == "nt": + self.proc.kill() + else: + os.killpg(self.proc.pid, signal.SIGKILL) + except Exception: + os.kill(self.proc.pid, signal.SIGKILL) + self.proc = None + + def close(self): + logger.debug("Closing worker process...") + if self.proc and self.proc.poll() is None: + try: + logger.debug("Sending bye RPC to worker") + self._send_rpc("bye", {}) + # ignore response; we're closing anyway + except Exception as e: + logger.debug(f"Bye RPC failed (expected): {e}") + pass + self._hard_kill() + logger.debug("Worker process closed") + + def __del__(self): + with contextlib.suppress(Exception): + self.close() + + # ---------- JSON-RPC helpers ---------- + + def _send_rpc(self, method: str, params: Dict[str, Any], timeout: Optional[float] = None) -> Any: + if self.proc is None or self.proc.stdin is None or self.proc.stdout is None: + logger.error("Worker not running, cannot send RPC") + raise Tempo2Crashed("worker not running") + + self._id += 1 + rid = self._id + # Only log debug for non-get methods to reduce noise + if method != "get": + logger.debug(f"Sending RPC {method} (id: {rid})") + + frame = { + "jsonrpc": "2.0", + "id": rid, + "method": method, + "params_b64": _b64_dumps_py(params), + } + line = json.dumps(frame) + "\n" 
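+        # One request frame per line: newline-delimited JSON, with the Python
+        # parameters carried as a base64-encoded cloudpickle blob in params_b64.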
+ + # Protect frames from interleaving + import threading + + if not hasattr(self, "_rpc_lock"): + self._rpc_lock = threading.Lock() + try: + with self._rpc_lock: + self.proc.stdin.write(line) + self.proc.stdin.flush() + except Exception as e: + logger.error(f"Failed to send RPC {method}: {e}") + self._hard_kill() + raise Tempo2Crashed(f"send failed: {e!r}") + + # Wait for response + t = self.policy.call_timeout_s if timeout is None else timeout + if t is None: + logger.debug(f"Waiting for RPC {method} response (no timeout)") + else: + logger.debug(f"Waiting for RPC {method} response (timeout: {t}s)") + resp_line = self._readline_with_timeout(t) + if resp_line is None: + if t is None: + logger.error(f"RPC {method} failed - worker disconnected") + self._hard_kill() + raise Tempo2Crashed(f"RPC '{method}' failed - worker disconnected") + else: + logger.error(f"RPC {method} timed out after {t}s") + self._hard_kill() + raise Tempo2Timeout(f"RPC '{method}' timed out") + + try: + resp = json.loads(resp_line) + except Exception as e: + logger.error(f"Failed to parse RPC {method} response: {e}") + self._hard_kill() + raise Tempo2ProtocolError(f"malformed response: {resp_line!r}") + + if resp.get("id") != rid: + logger.error(f"RPC {method} id mismatch: expected {rid}, got {resp.get('id')}") + self._hard_kill() + raise Tempo2ProtocolError(f"mismatched id in response: {resp.get('id')} vs {rid}") + + if "error" in resp and resp["error"] is not None: + err = resp["error"] + msg = err.get("message", "error") + data = err.get("data", "") + logger.error(f"RPC {method} failed: {msg}") + tail = "" + if getattr(self, "_log_buf", None) and (self.policy.include_stderr_tail_in_errors or 0) > 0: + try: + tail_lines = list(self._log_buf)[-self.policy.include_stderr_tail_in_errors :] + if tail_lines: + blob = "\n".join(tail_lines) + max_bytes = 16384 + if len(blob) > max_bytes: + blob = blob[-max_bytes:] + tail = "\n--- tempo2 stderr (tail) ---\n" + blob + except Exception: + tail = "" + raise Tempo2Error(f"{msg}\n{data}{tail}") + + # Only log debug for non-get methods to reduce noise + if method != "get": + logger.debug(f"RPC {method} completed successfully") + result_b64 = resp.get("result_b64", None) + return _b64_loads_py(result_b64) if result_b64 is not None else None + + # Public RPCs + def ctor(self, kwargs: Dict[str, Any], preload_residuals: bool): + logger.info(f"Constructing tempopulsar with kwargs: {kwargs}") + logger.info(f"Preload residuals: {preload_residuals}") + return self._send_rpc("ctor", {"kwargs": kwargs, "preload_residuals": preload_residuals}) + + def get(self, name: str): + logger.debug(f"Getting attribute: {name}") + return self._send_rpc("get", {"name": name}) + + def get_kind(self, name: str): + """Return (kind, value) where kind in {"value","callable","missing"}. + For legacy workers without get-kind support, assumes raw value => ("value", value). 
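+
+        Illustrative returns: get_kind("name") -> ("value", "J1713+0747");
+        get_kind("residuals") -> ("callable", None); an unknown attribute
+        yields ("missing", None).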
+ """ + resp = self._send_rpc("get", {"name": name}) + if isinstance(resp, dict) and "kind" in resp: + return (resp.get("kind"), resp.get("value")) + return ("value", resp) + + def dir(self): + """Return list of public attribute names.""" + return self._send_rpc("dir", {}) + + def set(self, name: str, value: Any): + logger.debug(f"Setting attribute: {name}") + return self._send_rpc("set", {"name": name, "value": value}) + + def call(self, name: str, args=(), kwargs=None): + logger.debug(f"Calling method: {name} with args={args}, kwargs={kwargs}") + return self._send_rpc("call", {"name": name, "args": tuple(args), "kwargs": dict(kwargs or {})}) + + def setitem(self, name: str, index, value: Any): + logger.debug(f"Setting array slice: {name}[{index}] = ") + return self._send_rpc("setitem", {"name": name, "index": index, "value": value}) + + def get_slice(self, name: str, index): + logger.debug(f"Getting array slice: {name}[{index}]") + return self._send_rpc("get_slice", {"name": name, "index": index}) + + def rss(self) -> Optional[int]: + try: + logger.debug("Getting worker RSS memory usage") + return self._send_rpc("rss", {}) + except Exception as e: + logger.warning(f"Failed to get RSS: {e}") + return None + + def logs(self, tail: int = 500) -> str: + try: + return "\n".join(list(self._log_buf)[-max(0, tail) :]) + except Exception: + return "" + + +# ------------------------- Command resolution (env_name) -------------------- # + + +def _detect_environment_type(env_name: str) -> str: + """ + Return "conda", "venv", "arch", "python", or "unknown". + """ + if env_name.startswith("python:"): + return "python" + + # conda family + for tool in ("conda", "mamba", "micromamba"): + try: + r = subprocess.run( + [tool, "run", "-n", env_name, "python", "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if r.returncode == 0: + return "conda:" + tool + except Exception: + pass + + # common venv locations + venv_paths = [ + Path.home() / ".venvs" / env_name / "bin" / "python", + Path.home() / "venvs" / env_name / "bin" / "python", + Path.home() / ".virtualenvs" / env_name / "bin" / "python", + Path.cwd() / env_name / "bin" / "python", + Path.cwd() / ".venv" / "bin" / "python", # only if env_name == '.venv' + # Additional common locations for containers/dev environments + Path("/opt/venvs") / env_name / "bin" / "python", + Path("/opt/virtualenvs") / env_name / "bin" / "python", + Path("/usr/local/venvs") / env_name / "bin" / "python", + Path("/home") / "venvs" / env_name / "bin" / "python", + # Try to find any python executable with the env name in the path + Path(f"/opt/venvs/{env_name}/bin/python"), + Path(f"/opt/virtualenvs/{env_name}/bin/python"), + ] + for p in venv_paths: + if p.exists(): + return "venv" + + if env_name in ("arch", "rosetta", "system"): + return "arch" + + return "unknown" + + +def _find_venv_python_path(env_name: str) -> Optional[str]: + venv_paths = [ + Path.home() / ".venvs" / env_name / "bin" / "python", + Path.home() / "venvs" / env_name / "bin" / "python", + Path.home() / ".virtualenvs" / env_name / "bin" / "python", + Path.cwd() / env_name / "bin" / "python", + Path.cwd() / ".venv" / "bin" / "python", + # Additional common locations for containers/dev environments + Path("/opt/venvs") / env_name / "bin" / "python", + Path("/opt/virtualenvs") / env_name / "bin" / "python", + Path("/usr/local/venvs") / env_name / "bin" / "python", + Path("/home") / "venvs" / env_name / "bin" / "python", + # Try to find any python executable with the env name in the path + 
Path(f"/opt/venvs/{env_name}/bin/python"), + Path(f"/opt/virtualenvs/{env_name}/bin/python"), + ] + for p in venv_paths: + if p.exists(): + return str(p) + return None + + +def _resolve_worker_cmd(env_name: Optional[str]) -> Tuple[List[str], bool]: + """ + Build the subprocess command to run the worker and whether we require x86_64. + Returns (cmd, require_x86_64) + """ + + # Base invocation that runs this file in worker mode. + # Find the src directory dynamically + current_file = Path(__file__).resolve() + src_dir = current_file.parent.parent # Go up from libstempo/sandbox.py to src/ + src_path = str(src_dir) + + def python_to_worker_cmd(python_exe: str) -> List[str]: + """Build command to run worker with given Python executable.""" + return [ + python_exe, + "-c", + ( + f"import os, sys; " + f"os.environ['TEMPO2_SANDBOX_PROTO_FD'] = str(os.dup(1)); " + f"os.dup2(2, 1); " + f"sys.stdout = sys.stderr; " + f"sys.path.insert(0, '{src_path}'); " + f"import libstempo.sandbox as m; " + f"m._worker_stdio_main()" + ), + ] + + arch_prefix_env = os.environ.get("TEMPO2_SANDBOX_WORKER_ARCH_PREFIX", "").strip() + require_x86_64 = False + + # No env_name -> use current python (no Rosetta) + if env_name is None: + py = sys.executable + return (python_to_worker_cmd(py), False) + + # Explicit python path + if env_name.startswith("python:"): + py = env_name.split(":", 1)[1] + return (python_to_worker_cmd(py), False) + + etype = _detect_environment_type(env_name) + + # conda/mamba/micromamba + if etype.startswith("conda:"): + tool = etype.split(":", 1)[1] + cmd = [tool, "run", "-n", env_name] + python_to_worker_cmd("python") + # Choosing to require x86_64 only if user *explicitly* asks via arch prefix or env_name == "arch" + require_x86_64 = "arch" in env_name.lower() + if arch_prefix_env: + cmd = arch_prefix_env.split() + cmd + require_x86_64 = True + return (cmd, require_x86_64) + + # venv + if etype == "venv": + py = _find_venv_python_path(env_name) + if not py: + raise Tempo2Error(f"virtualenv '{env_name}' not found in common locations") + cmd = python_to_worker_cmd(py) + if arch_prefix_env: + cmd = arch_prefix_env.split() + cmd + require_x86_64 = True + return (cmd, require_x86_64) + + # system Rosetta + if etype == "arch": + # try system python (python3 or python) + py = shutil.which("python3") or shutil.which("python") + if not py: + raise Tempo2Error("could not find system python for arch mode") + arch = arch_prefix_env.split() if arch_prefix_env else ["arch", "-x86_64"] + require_x86_64 = True + return (arch + python_to_worker_cmd(py), require_x86_64) + + raise Tempo2Error( + f"Environment '{env_name}' not found. " "Use a conda env name, a venv name, 'arch', or 'python:/abs/python'." + ) + + +# ------------------------------ Public proxy ------------------------------- # + + +@dataclasses.dataclass +class _State: + """Internal state tracking for tempopulsar proxy instances.""" + + created_at: float + calls_ok: int + # State cache for crash recovery + param_cache: Dict[str, Dict[str, Any]] = dataclasses.field( + default_factory=dict + ) # {'RAJ': {'val': 5.016, 'fit': True}, ...} + array_cache: Dict[str, Any] = dataclasses.field( + default_factory=dict + ) # {'stoas': modified_array, 'toaerrs': modified_array} + # Crash recovery statistics + crash_count: int = 0 + last_crash_at: Optional[float] = None + + +class tempopulsar: + """ + Proxy for libstempo.tempopulsar living inside an isolated subprocess. 
+ + This class provides a drop-in replacement for libstempo.tempopulsar that runs + in a separate process to prevent crashes from affecting the main kernel. + All constructor arguments are forwarded to libstempo.tempopulsar unchanged. + + The proxy automatically handles: + - Worker process lifecycle management + - Automatic retry on failures + - Worker recycling based on age, call count, or memory usage + - JSON-RPC communication over stdio + + Args: + env_name: Environment name (conda env or venv name, 'arch', or 'python:/abs/python'). + If None (default), uses the current Python environment. + policy: Optional Policy instance to configure worker behavior + **kwargs: Additional arguments passed to libstempo.tempopulsar + + Example: + >>> psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False) + >>> residuals = psr.residuals() + >>> design_matrix = psr.designmatrix() + """ + + __slots__ = ( + "_policy", + "_wp", + "_state", + "_ctor_kwargs", + "_env_name", + "_require_x86", + ) + + def __init__(self, env_name: Optional[str] = None, **kwargs): + policy = kwargs.pop("policy", None) + self._policy: Policy = policy if isinstance(policy, Policy) else Policy() + self._env_name = env_name + self._ctor_kwargs = dict(kwargs) + self._wp: Optional[_WorkerProc] = None + self._state = _State(created_at=time.time(), calls_ok=0) + self._require_x86 = False + + logger.info(f"Creating tempopulsar with env_name='{env_name}', kwargs={self._ctor_kwargs}") + logger.info(f"Using policy: ctor_retry={self._policy.ctor_retry}, ctor_backoff={self._policy.ctor_backoff}s") + self._construct_with_retries() + + # --------------- construction / reconstruction with retries --------------- # + + def _construct_with_retries(self): + logger.info(f"Starting construction with {self._policy.ctor_retry + 1} total attempts") + + # Fast-fail on missing input files to avoid noisy retries + try: + parfile = self._ctor_kwargs.get("parfile") + timfile = self._ctor_kwargs.get("timfile") + cwd = os.getcwd() + if parfile and not Path(parfile).exists(): + raise Tempo2ConstructorFailed( + f"parfile not found: {parfile} (cwd: {cwd}). Provide an absolute path or correct relative path." + ) + if timfile and not Path(timfile).exists(): + raise Tempo2ConstructorFailed( + f"timfile not found: {timfile} (cwd: {cwd}). Provide an absolute path or correct relative path." 
+ ) + except Tempo2ConstructorFailed: + # Re-raise to surface a clean single error without retries + raise + except Exception: + # Do not block construction for unexpected preflight errors + logger.debug("Preflight path check skipped due to unexpected error:", exc_info=True) + + # Proactive TOA counting to avoid "Too many TOAs" errors + if self._policy.auto_nobs_retry: + self._proactive_nobs_setup() + + last_exc: Optional[Exception] = None + for attempt in range(1 + self._policy.ctor_retry): + logger.info(f"Construction attempt {attempt + 1}/{self._policy.ctor_retry + 1}") + try: + cmd, require_x86 = _resolve_worker_cmd(self._env_name) + self._require_x86 = require_x86 + logger.debug(f"Resolved worker command: {' '.join(cmd)}") + logger.debug(f"Require x86_64: {require_x86}") + + self._wp = _WorkerProc(self._policy, cmd, require_x86_64=require_x86) + # ctor on the worker (libstempo.tempopulsar) + logger.info("Calling constructor on worker...") + self._wp.ctor(self._ctor_kwargs, preload_residuals=self._policy.preload_residuals) + self._state.created_at = time.time() + self._state.calls_ok = 0 + logger.info(f"Construction successful on attempt {attempt + 1}") + + # Restore state after successful reconstruction + self._restore_state_after_reconstruction() + return + except Exception as e: + logger.warning(f"Construction attempt {attempt + 1} failed: {e}") + last_exc = e + + # Record crash if this is a retry (not the first attempt) + if attempt > 0: + self._record_crash() + + # If it's a file-not-found style error, fail fast without retries + msg = str(e) + if any( + t in msg + for t in ("Cannot find parfile", "Cannot find timfile", "parfile not found", "timfile not found") + ): + logger.error("Input file missing; not retrying constructor.") + break + # kill and retry + try: + if self._wp: + logger.debug("Cleaning up failed worker") + self._wp.close() + except Exception as cleanup_e: + logger.warning(f"Cleanup failed: {cleanup_e}") + pass + self._wp = None + if attempt < self._policy.ctor_retry: # Don't sleep after last attempt + logger.info(f"Waiting {self._policy.ctor_backoff}s before retry...") + time.sleep(self._policy.ctor_backoff) + logger.error(f"All construction attempts failed. 
Last error: {last_exc}") + raise Tempo2ConstructorFailed(f"tempopulsar ctor failed after retries: {last_exc}") + + def _proactive_nobs_setup(self): + """Proactively count TOAs and add nobs parameter if needed to avoid 'Too many TOAs' errors.""" + try: + timfile = self._ctor_kwargs.get("timfile") + if not timfile: + logger.debug("No timfile specified, skipping proactive nobs setup") + return + + timfile_path = Path(timfile) + if not timfile_path.exists(): + logger.warning(f"TIM file does not exist: {timfile_path}") + return + + logger.info(f"Proactively counting TOAs in {timfile_path}") + analyzer = TimFileAnalyzer() + toa_count = analyzer.count_toas(timfile_path) + + if toa_count > self._policy.nobs_threshold: + maxobs_with_margin = int(toa_count * self._policy.nobs_safety_margin) + self._ctor_kwargs["maxobs"] = maxobs_with_margin + logger.info( + f"Proactively added maxobs={maxobs_with_margin} parameter " + f"(TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, " + f"margin: {self._policy.nobs_safety_margin})" + ) + else: + logger.debug( + f"TOA count {toa_count} below threshold {self._policy.nobs_threshold}, no maxobs parameter needed" + ) + + except Exception as e: + logger.warning(f"Proactive nobs setup failed: {e}") + # Don't raise - this is just optimization, construction should still work + + # ----------------------------- state management ----------------------------- # + + def _capture_param_state(self, param_name: str, field: str, value: Any) -> None: + """Capture parameter state for crash recovery.""" + if param_name not in self._state.param_cache: + self._state.param_cache[param_name] = {} + self._state.param_cache[param_name][field] = value + logger.debug(f"Captured param state: {param_name}.{field} = {value}") + + def _capture_array_state(self, array_name: str, value: Any) -> None: + """Capture array state for crash recovery.""" + self._state.array_cache[array_name] = value + logger.debug(f"Captured array state: {array_name}") + + def _restore_state_after_reconstruction(self) -> None: + """Restore parameter values, fit flags, and array modifications after worker reconstruction.""" + if not self._state.param_cache and not self._state.array_cache: + logger.debug("No state to restore") + return + + logger.info(f"Restoring state: {len(self._state.param_cache)} params, {len(self._state.array_cache)} arrays") + + # Restore parameter values and fit flags + for param_name, param_state in self._state.param_cache.items(): + for field, value in param_state.items(): + try: + self._wp.set(f"{param_name}.{field}", value) + logger.debug(f"Restored param: {param_name}.{field} = {value}") + except Exception as e: + logger.warning(f"Failed to restore param {param_name}.{field}: {e}") + + # Restore array modifications + for array_name, array_data in self._state.array_cache.items(): + try: + self._wp.setitem(array_name, slice(None), array_data) + logger.debug(f"Restored array: {array_name}") + except Exception as e: + logger.warning(f"Failed to restore array {array_name}: {e}") + + logger.info("State restoration completed") + + def _record_crash(self) -> None: + """Record crash statistics.""" + self._state.crash_count += 1 + self._state.last_crash_at = time.time() + logger.info(f"Worker crash recorded (total crashes: {self._state.crash_count})") + + def get_crash_stats(self) -> Dict[str, Any]: + """Get crash recovery statistics.""" + return { + "crash_count": self._state.crash_count, + "last_crash_at": self._state.last_crash_at, + "worker_age_s": time.time() - self._state.created_at if 
self._wp else None, + "calls_since_creation": self._state.calls_ok, + } + + # ----------------------------- recycling policy --------------------------- # + + def _should_recycle(self) -> bool: + if self._wp is None: + logger.debug("Should recycle: worker is None") + return True + + age = time.time() - self._state.created_at + + # Check age limit (if set) + if self._policy.max_age_s is not None and age > self._policy.max_age_s: + logger.info(f"Should recycle: worker age {age:.1f}s exceeds max_age_s {self._policy.max_age_s}") + return True + + # Check call limit (if set) + if self._policy.max_calls_per_worker is not None and self._state.calls_ok >= self._policy.max_calls_per_worker: + logger.info( + f"Should recycle: calls_ok {self._state.calls_ok} exceeds " + f"max_calls_per_worker {self._policy.max_calls_per_worker}" + ) + return True + + # Check RSS limit (if set) + if self._policy.rss_soft_limit_mb is not None: + rss = self._wp.rss() + if rss and rss > self._policy.rss_soft_limit_mb: + logger.info(f"Should recycle: RSS {rss}MB exceeds limit {self._policy.rss_soft_limit_mb}MB") + return True + + logger.debug(f"Worker still healthy: age={age:.1f}s, calls={self._state.calls_ok}") + return False + + def _recycle(self): + logger.info("Recycling worker (creating new one)") + if self._wp is not None: + logger.debug("Closing old worker") + with contextlib.suppress(Exception): + self._wp.close() + self._wp = None + logger.debug("Constructing new worker") + self._construct_with_retries() + + # ---------------------------- RPC convenience ----------------------------- # + + def _rpc(self, call: str, **payload): + if self._wp is None: + logger.debug("Worker is None, constructing...") + self._construct_with_retries() + if self._should_recycle(): + logger.info("Worker needs recycling") + self._recycle() + assert self._wp is not None + try: + if call == "get": + out = self._wp.get(payload["name"]) + elif call == "set": + out = self._wp.set(payload["name"], payload["value"]) + elif call == "call": + out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) + elif call == "setitem": + out = self._wp.setitem(payload["name"], payload.get("index"), payload.get("value")) + else: + raise Tempo2ProtocolError(f"unknown call {call}") + self._state.calls_ok += 1 + logger.debug(f"RPC {call} successful, total calls: {self._state.calls_ok}") + return out + except (Tempo2Timeout, Tempo2Crashed, Tempo2ProtocolError, Tempo2Error) as e: + # Only log warnings for actual failures, not expected attribute discovery failures + if call != "get" or not str(e).startswith("AttributeError"): + logger.warning(f"RPC {call} failed with {type(e).__name__}: {e}") + logger.info("Attempting automatic worker recycle and retry") + # Record crash and recycle + self._record_crash() + self._recycle() + assert self._wp is not None + if call == "get": + out = self._wp.get(payload["name"]) + elif call == "set": + out = self._wp.set(payload["name"], payload["value"]) + else: + out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) + self._state.calls_ok += 1 + if call != "get" or not str(e).startswith("AttributeError"): + logger.info(f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}") + return out + + # ------------------------ Attribute proxying magic ------------------------ # + + def __getattr__(self, name: str): + # Filter out IPython-specific attributes to prevent infinite loops + if name.startswith("_ipython_") or name in { + 
"_ipython_canary_method_should_not_exist_", + "_repr_mimebundle_", + "_repr_html_", + "_repr_json_", + "_repr_latex_", + "_repr_png_", + "_repr_jpeg_", + "_repr_svg_", + "_repr_pdf_", + }: + logger.debug(f"Filtering out IPython attribute: {name}") + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def _remote_method(*args, **kwargs): + return self._rpc("call", name=name, args=args, kwargs=kwargs) + + # Non-exceptional discovery using get-kind + kind, payload = self._wp.get_kind(name) + if kind == "value": + # If worker returned a safe libstempo param marker, expose a proxy + if isinstance(payload, dict) and payload.get("__libstempo_param__"): + return _ParamProxy(self, name) + return payload + if kind == "callable": + return _remote_method + if kind == "missing": + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{name}'") + + def __setattr__(self, name: str, value: Any): + if name in tempopulsar.__slots__: + return object.__setattr__(self, name, value) + + # Capture state for crash recovery + self._capture_array_state(name, value) + + _ = self._rpc("set", name=name, value=value) + return None + + def __dir__(self): + """Return a list of available attributes for dir() function.""" + try: + return list(self._wp.dir()) + except Exception: + pass + + # Fallback: return a basic set of common tempopulsar attributes + return [ + "name", + "nobs", + "stoas", + "toaerrs", + "freqs", + "ndim", + "residuals", + "designmatrix", + "toas", + "fit", + "vals", + "errs", + "pars", + "flags", + "flagvals", + "savepar", + "savetim", + "chisq", + "rms", + "ssbfreqs", + "logs", + ] + + # Explicit helpers for common call shapes + def residuals(self, **kwargs): + return self._rpc("call", name="residuals", kwargs=kwargs) + + def designmatrix(self, **kwargs): + return self._rpc("call", name="designmatrix", kwargs=kwargs) + + def toas(self, **kwargs): + return self._rpc("call", name="toas", kwargs=kwargs) + + def fit(self, **kwargs): + return self._rpc("call", name="fit", kwargs=kwargs) + + def logs(self, tail: int = 500) -> str: + return self._wp.logs(tail) if self._wp else "" + + # Mapping-style access to parameters, proxied to libstempo + def __getitem__(self, key: str): + return _ParamProxy(self, key) + + def __setitem__(self, key: str, value: Any): + raise TypeError("Direct assignment to parameters is not supported; set fields like psr['RAJ'].val = x") + + # Expose array-like attributes as write-through proxies + @property + def stoas(self): + return _ArrayProxy(self, "stoas") + + @property + def toaerrs(self): + return _ArrayProxy(self, "toaerrs") + + @property + def freqs(self): + return _ArrayProxy(self, "freqs") + + +class _ParamProxy: + __slots__ = ("_parent", "_name") + + def __init__(self, parent: tempopulsar, name: str) -> None: + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_name", name) + + def __getattr__(self, attr: str): + # Fetch field via dotted get path (e.g., RAJ.val), honoring get-kind + kind, payload = self._parent._wp.get_kind(f"{self._name}.{attr}") + if kind == "value": + return payload + if kind == "callable": + # Expose a callable that routes via call with dotted name + def _remote_method(*args, **kwargs): + return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs) + + return _remote_method + if kind == "missing": + raise AttributeError(f"'{self._name}' has no attribute '{attr}'") 
+ raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{self._name}.{attr}'") + + def __setattr__(self, attr: str, value: Any) -> None: + if attr in _ParamProxy.__slots__: + return object.__setattr__(self, attr, value) + + # Capture parameter state for crash recovery + self._parent._capture_param_state(self._name, attr, value) + + _ = self._parent._rpc("set", name=f"{self._name}.{attr}", value=value) + return None + + def __repr__(self) -> str: + try: + v = self.__getattr__("val") + e = self.__getattr__("err") + return f"" + except Exception: + return f"" + + +class _ArrayProxy: + __slots__ = ("_parent", "_name") + + def __init__(self, parent: tempopulsar, name: str) -> None: + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_name", name) + + # numpy reads + def __array__(self, dtype=None): + import numpy as _np + + # Use get-kind to get the array data + kind, payload = self._parent._wp.get_kind(self._name) + if kind == "value": + arr = payload + elif kind == "callable": + # Arrays are not callable; treat as empty + arr = [] + else: + arr = [] + + a = _np.asarray(arr) + if dtype is not None: + a = a.astype(dtype, copy=False) + return a + + def __len__(self): + a = self.__array__() + return int(a.shape[0]) if getattr(a, "ndim", 1) > 0 else 1 + + @property + def shape(self): + return self.__array__().shape + + @property + def dtype(self): + return self.__array__().dtype + + # python indexing + def __getitem__(self, idx): + return self._parent._wp.get_slice(self._name, idx) + + def __setitem__(self, idx, value): + # Capture array state for crash recovery + # For full array replacement (slice(None)), capture the entire array + if idx == slice(None): + self._parent._capture_array_state(self._name, value) + else: + # For partial updates, we need to get the current array and apply the change + # This is more complex, so for now we'll capture the full array after the change + try: + current_array = self.__array__() + if hasattr(current_array, "copy"): + new_array = current_array.copy() + new_array[idx] = value + self._parent._capture_array_state(self._name, new_array) + except Exception: + # If we can't capture the state, continue anyway + pass + + _ = self._parent._rpc("setitem", name=self._name, index=idx, value=value) + return None + + def __repr__(self) -> str: + try: + return repr(self.__array__()) + except Exception: + return f"<_ArrayProxy {self._name}>" + + def __str__(self) -> str: + try: + return str(self.__array__()) + except Exception: + return self.__repr__() + + # Delegate unknown attributes/methods to the numpy array + def __getattr__(self, name: str): + arr = self.__array__() + return getattr(arr, name) + + def __del__(self): + with contextlib.suppress(Exception): + if self._wp is not None: + self._wp.close() + + +# -------------------------- Bulk loader (optional) -------------------------- # + + +@dataclass +class LoadReport: + par: str + tim: Optional[str] + attempts: int + ok: bool + error: Optional[str] = None + retried: bool = False + + +def load_many( + pairs: Iterable[Tuple[str, Optional[str]]], + policy: Optional[Policy] = None, + parallel: int = 8, +) -> Tuple[Dict[str, tempopulsar], Dict[str, LoadReport], List[LoadReport]]: + """ + Bulk-load many pulsars with bounded parallelism. + Returns: (ok_by_name, retried_by_name, failed_list) + + ok_by_name: {psr_name: tempopulsar proxy} + retried_by_name: {psr_name: LoadReport} (those that required >=1 retry) + failed_list: [LoadReport,...] 
+ """ + pol = policy if isinstance(policy, Policy) else Policy() + logger.info(f"Starting bulk load of {len(list(pairs))} pulsars with {parallel} parallel workers") + logger.info(f"Using policy: ctor_retry={pol.ctor_retry}, ctor_backoff={pol.ctor_backoff}s") + + def _one(par, tim): + """Load a single pulsar with retry logic for bulk loading.""" + logger.debug(f"Loading pulsar: par={par}, tim={tim}") + attempts = 0 + report = LoadReport(par=par, tim=tim, attempts=0, ok=False) + last_exc = None + for _ in range(1 + pol.ctor_retry): + attempts += 1 + try: + psr = tempopulsar(parfile=par, timfile=tim, policy=pol) + name = getattr(psr, "name") + report.attempts = attempts + report.ok = True + report.retried = attempts > 1 + logger.info(f"Successfully loaded {name} in {attempts} attempt(s)") + return ("ok", name, psr, report) + except Exception as e: + logger.warning(f"Failed to load {par} (attempt {attempts}): {e}") + last_exc = e + time.sleep(pol.ctor_backoff) + report.attempts = attempts + report.ok = False + report.error = f"{last_exc.__class__.__name__}: {last_exc}" + logger.error(f"Failed to load {par} after {attempts} attempts: {last_exc}") + return ("fail", None, None, report) + + ok: Dict[str, tempopulsar] = {} + retried: Dict[str, LoadReport] = {} + failed: List[LoadReport] = [] + + with ThreadPoolExecutor(max_workers=max(1, parallel)) as ex: + futs = {ex.submit(_one, par, tim): (par, tim) for (par, tim) in pairs} + for fut in as_completed(futs): + kind, name, psr, report = fut.result() + if kind == "ok": + ok[name] = psr + if report.retried: + retried[name] = report + else: + failed.append(report) + + logger.info(f"Bulk load completed: {len(ok)} successful, {len(retried)} retried, {len(failed)} failed") + return ok, retried, failed + + +# ------------------------------- Quick helpers ------------------------------ # + + +def configure_logging(level: str = "INFO", log_file: Optional[str] = None, enable_console: bool = True): + """ + Configure standard logging for the sandbox. + + Args: + level: Log level ("DEBUG", "INFO", "WARNING", "ERROR") + log_file: Optional file path to log to + enable_console: Whether to log to console + """ + # Get the sandbox logger + sandbox_logger = logging.getLogger(__name__) + + # Clear existing handlers + sandbox_logger.handlers.clear() + + # Set level + numeric_level = getattr(logging, level.upper(), logging.INFO) + sandbox_logger.setLevel(numeric_level) + + # Create formatter + formatter = logging.Formatter("%(asctime)s | %(levelname)-8s | tempo2_sandbox | %(message)s", datefmt="%H:%M:%S") + + # Add console handler if requested + if enable_console: + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setFormatter(formatter) + sandbox_logger.addHandler(console_handler) + + # Add file handler if requested + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + sandbox_logger.addHandler(file_handler) + + logger.info(f"Logging configured: level={level}, console={enable_console}, file={log_file}") + + +def setup_instructions(env_name: str = "tempo2_intel"): + """Print setup instructions for creating a tempo2 environment. + + This is a utility function to help users set up their environment + for using the sandbox with different Python environments. + """ + print("Setup instructions for environment '{}':".format(env_name)) + print("\n1. 
Conda (recommended):") + print(f" conda create -n {env_name} python=3.11") + print(f" conda activate {env_name}") + print(" conda install -c conda-forge tempo2 libstempo") + print(f' # then just: psr = tempopulsar(..., env_name="{env_name}")') + print("\n2. Virtual Environment (Rosetta):") + print(f" arch -x86_64 /usr/local/bin/python3 -m venv ~/.venvs/{env_name}") + print(f" source ~/.venvs/{env_name}/bin/activate") + print(" pip install tempo2 libstempo") + print(f' # then just: psr = tempopulsar(..., env_name="{env_name}")') + print("\n3. System Python with Rosetta:") + print(" # Install Intel Python first (or use system one under arch).") + print(' # You can force Rosetta via TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64"') + print(' # then: psr = tempopulsar(..., env_name="arch")') + + +def detect_and_guide(env_name: str): + """Detect environment type and provide guidance for setup. + + This is a utility function to help users understand what type + of environment they have and how to use it with the sandbox. + """ + et = _detect_environment_type(env_name) + print(f"Environment detection for '{env_name}': {et}") + if et.startswith("conda:"): + print("✅ Conda env detected; just use env_name as given.") + elif et == "venv": + p = _find_venv_python_path(env_name) + if p: + print(f"✅ venv detected at {p}") + else: + print("❌ venv name matched, but python path not resolved.") + elif et == "arch": + print("✅ Rosetta/system arch mode will be used.") + elif et == "python": + print("✅ Using explicit Python path.") + else: + print("❌ Not found. Use conda env name, venv name, 'arch', or 'python:/abs/python'.") + + +# ------------------------------ Module runner ------------------------------- # + +if __name__ == "__main__": + # If executed directly, act as worker (useful for manual debugging): + _worker_stdio_main() diff --git a/libstempo/tim_file_analyzer.py b/libstempo/tim_file_analyzer.py new file mode 100644 index 0000000..23ba0cd --- /dev/null +++ b/libstempo/tim_file_analyzer.py @@ -0,0 +1,534 @@ +"""TimFileAnalyzer - Fast TIM file analyzer without PINT dependencies. + +This module provides a lightweight class to quickly extract TOA MJD values +from TIM files using independent parsing logic that replicates PINT's functionality +without requiring PINT as a dependency. This is useful for environments where +PINT is not available or when you need a minimal implementation. + +The implementation replicates the core functionality from PINT's toa.py module +for parsing TOA lines and handling TIM file commands. 
+ +Author: Rutger van Haasteren -- rutger@vhaasteren.com +Date: 2025-10-10 +""" + +import re +import logging +from pathlib import Path +from typing import List, Set, Dict, Tuple, Optional + +logger = logging.getLogger(__name__) + + +# TOA commands that can appear in TIM files +TOA_COMMANDS = ( + "DITHER", + "EFAC", + "EMAX", + "EMAP", + "EMIN", + "EQUAD", + "FMAX", + "FMIN", + "INCLUDE", + "INFO", + "JUMP", + "MODE", + "NOSKIP", + "PHA1", + "PHA2", + "PHASE", + "SEARCH", + "SIGMA", + "SIM", + "SKIP", + "TIME", + "TRACK", + "ZAWGT", + "FORMAT", + "END", +) + +# Simple observatory name mapping (subset of PINT's observatory registry) +# This covers the most common observatories used in TOA files +OBSERVATORY_NAMES = { + # Common single-letter codes + "A": "Arecibo", + "AO": "Arecibo", + "ARECIBO": "Arecibo", + "B": "GBT", + "GBT": "GBT", + "GREEN_BANK": "GBT", + "C": "CHIME", + "CHIME": "CHIME", + "D": "DSS-43", + "E": "Effelsberg", + "EFFELSBERG": "Effelsberg", + "F": "FAST", + "FAST": "FAST", + "G": "GMRT", + "GMRT": "GMRT", + "H": "Hobart", + "HOBART": "Hobart", + "I": "IAR", + "IAR": "IAR", + "J": "Jodrell", + "JB": "Jodrell", + "JODRELL": "Jodrell", + "K": "Kalyazin", + "KALYAZIN": "Kalyazin", + "L": "Lovell", + "LOVELL": "Lovell", + "M": "MeerKAT", + "MEERKAT": "MeerKAT", + "N": "Nancay", + "NANCAY": "Nancay", + "O": "Ooty", + "OOTY": "Ooty", + "P": "Parkes", + "PARKES": "Parkes", + "Q": "Qitai", + "QITAI": "Qitai", + "R": "RATAN", + "RATAN": "RATAN", + "S": "Sardinia", + "SARDINIA": "Sardinia", + "T": "Tianma", + "TIANMA": "Tianma", + "U": "URUMQI", + "URUMQI": "URUMQI", + "V": "VLA", + "VLA": "VLA", + "W": "Westerbork", + "WESTERBORK": "Westerbork", + "X": "Xinjiang", + "XINJIANG": "Xinjiang", + "Y": "Yunnan", + "YUNNAN": "Yunnan", + "Z": "Zelenchukskaya", + "ZELENCHUKSKAYA": "Zelenchukskaya", + # Special codes + "@": "Barycenter", + "BARYCENTER": "Barycenter", + "SSB": "Barycenter", + # Spacecraft observatories + "FERMI": "Fermi", + "NICER": "NICER", + "SWIFT": "Swift", + "RXTE": "RXTE", + "XTE": "RXTE", +} + + +def _toa_format(line: str, fmt: str = "Unknown") -> str: + """Determine the type of a TOA line. + + Identifies a TOA line as one of the following types: + Comment, Command, Blank, Tempo2, Princeton, ITOA, Parkes, Unknown. + + This replicates PINT's _toa_format function. + """ + # Check for comments first + if ( + line.startswith("C ") + and len(line) > 2 + and not line[2].isdigit() # C followed by non-digit + or line.startswith("c ") + and len(line) > 2 + and not line[2].isdigit() # c followed by non-digit + or line.startswith("#") + or line.startswith("CC ") + ): + return "Comment" + + # Check for commands + if line.upper().lstrip().startswith(TOA_COMMANDS): + return "Command" + + # Check for blank lines + if re.match(r"^\s*$", line): + return "Blank" + + # Check for Princeton format: starts with single observatory code followed by space + if re.match(r"[0-9a-zA-Z@] ", line): + return "Princeton" + + # Check for Tempo2 format: long lines, explicitly marked as Tempo2, or structured like Tempo2 + if len(line) > 80 or fmt == "Tempo2": + return "Tempo2" + + # Additional Tempo2 detection: if it looks like a Tempo2 TOA line (has 5+ space-separated fields) + fields = line.split() + if len(fields) >= 5: + # Check if it looks like a Tempo2 TOA line: name freq mjd error obs [flags...] 
+ try: + # Try to parse as Tempo2: name freq mjd error obs + float(fields[1]) # frequency should be numeric + float(fields[2]) # MJD should be numeric + float(fields[3]) # error should be numeric + # If we get here, it looks like Tempo2 format + return "Tempo2" + except (ValueError, IndexError): + pass + + # Check for Parkes format: starts with space, has decimal at position 42 + if re.match(r"^ ", line) and len(line) > 41 and line[41] == ".": + return "Parkes" + + # Check for ITOA format: two non-space chars, decimal at position 15 + if re.match(r"\S\S", line) and len(line) > 14 and line[14] == ".": + return "ITOA" + + # Default to Unknown + return "Unknown" + + +def _get_observatory_name(obs_code: str) -> str: + """Get observatory name from observatory code. + + This is a simplified version of PINT's get_observatory function + that only handles the most common observatory codes. + + Args: + obs_code: Observatory code (e.g., 'A', 'AO', '@') + + Returns: + Observatory name + """ + obs_code_upper = obs_code.upper() + return OBSERVATORY_NAMES.get(obs_code_upper, obs_code_upper) + + +def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int, float]], dict]: + """Parse a one-line ASCII time-of-arrival. + + Return an MJD tuple and a dictionary of other TOA information. + The format can be one of: Comment, Command, Blank, Tempo2, + Princeton, ITOA, Parkes, or Unknown. + + This replicates PINT's _parse_TOA_line function. + """ + MJD = None + fmt = _toa_format(line, fmt) + d = dict(format=fmt) + + if fmt == "Princeton": + # Princeton format + # ---------------- + # columns item + # 1-1 Observatory (one-character code) '@' is barycenter + # 2-2 must be blank + # 16-24 Observing frequency (MHz) + # 25-44 TOA (decimal point must be in column 30 or column 31) + # 45-53 TOA uncertainty (microseconds) + # 69-78 DM correction (pc cm^-3) + try: + # Handle both fixed-width and space-separated Princeton format + if len(line) >= 78: # Fixed-width format + d["obs"] = _get_observatory_name(line[0].upper()) + d["freq"] = float(line[15:24]) + d["error"] = float(line[44:53]) + ii, ff = line[24:44].split(".") + ii = int(ii) + # For very old TOAs, see https://tempo.sourceforge.net/ref_man_sections/toa.txt + if ii < 40000: + ii += 39126 + MJD = (ii, float(f"0.{ff}")) + try: + d["ddm"] = str(float(line[68:78])) + except ValueError: + d["ddm"] = str(0.0) + else: # Space-separated format (fallback) + fields = line.split() + if len(fields) >= 4: + d["obs"] = _get_observatory_name(fields[0].upper()) + d["freq"] = float(fields[1]) + d["error"] = float(fields[3]) + # Parse MJD + if "." in fields[2]: + ii, ff = fields[2].split(".") + ii = int(ii) + if ii < 40000: + ii += 39126 + MJD = (ii, float(f"0.{ff}")) + else: + ii = int(fields[2]) + if ii < 40000: + ii += 39126 + MJD = (ii, 0.0) + d["ddm"] = str(0.0) # Default DM correction + else: + raise ValueError("Not enough fields for Princeton format") + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Princeton format line: {e}") + d["format"] = "Unknown" + + elif fmt == "Tempo2": + # This could use more error catching... + try: + fields = line.split() + d["name"] = fields[0] + d["freq"] = float(fields[1]) + if "." 
in fields[2]: + ii, ff = fields[2].split(".") + MJD = (int(ii), float(f"0.{ff}")) + else: + MJD = (int(fields[2]), 0.0) + d["error"] = float(fields[3]) + d["obs"] = _get_observatory_name(fields[4].upper()) + # All the rest should be flags + flags = fields[5:] + + # Flags and flag-values should be given in pairs. + # The for loop below will fail otherwise. + if len(flags) % 2 != 0: + raise ValueError( + f"Flags and flag-values should be given in pairs. The given flags are {' '.join(flags)}" + ) + + for i in range(0, len(flags), 2): + k, v = flags[i].lstrip("-"), flags[i + 1] + if k in ["error", "freq", "scale", "MJD", "flags", "obs", "name"]: + raise ValueError(f"TOA flag ({k}) will overwrite TOA parameter!") + if not k: + raise ValueError(f"The string {repr(flags[i])} is not a valid flag") + d[k] = v + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Tempo2 format line: {e}") + d["format"] = "Unknown" + + elif fmt == "Command": + d[fmt] = line.split() + elif fmt == "Parkes": + """ + columns item + 1-1 Must be blank + 26-34 Observing Frequency (MHz) + 35-55 TOA (decimal point must be in column 42) + 56-63 Phase offset (fraction of P0, added to TOA) + 64-71 TOA uncertainty + 80-80 Observatory (1 character) + """ + try: + d["name"] = line[1:25] + d["freq"] = float(line[25:34]) + ii = line[34:41] + ff = line[42:55] + MJD = int(ii), float(f"0.{ff}") + phaseoffset = float(line[55:62]) + if phaseoffset != 0: + raise ValueError(f"Cannot interpret Parkes format with phaseoffset={phaseoffset} yet") + d["error"] = float(line[63:71]) + d["obs"] = _get_observatory_name(line[79].upper()) + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Parkes format line: {e}") + d["format"] = "Unknown" + + elif fmt == "ITOA": + raise RuntimeError(f"TOA format '{fmt}' not implemented yet") + elif fmt not in ["Blank", "Comment"]: + raise RuntimeError(f"Unable to identify TOA format for line {line!r}, expecting {fmt}") + return MJD, d + + +class TimFileAnalyzer: + """Fast TIM file analyzer for timespan calculation without PINT dependencies. + + This class efficiently extracts TOA MJD values from TIM files using + independent parsing logic that replicates PINT's functionality, + providing both performance and robustness for timespan calculations. + + The analyzer caches results per file to avoid duplicate parsing when both + timespan and TOA count are needed for the same file. + """ + + def __init__(self): + """Initialize the TIM file analyzer.""" + self.logger = logger + self._processed_files: Set[Path] = set() + # Cache for storing timespan and TOA counts per file (not MJD values to save memory) + self._file_cache: Dict[Path, Tuple[float, int]] = {} + + def _get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: + """Get timespan and TOA count from TIM file, using cache if available. 
+ + Args: + tim_file_path: Path to the TIM file + + Returns: + Tuple of (timespan_in_days, toa_count) + """ + # Check cache first + if tim_file_path in self._file_cache: + self.logger.debug(f"Using cached data for {tim_file_path}") + return self._file_cache[tim_file_path] + + try: + self._processed_files.clear() + mjd_values = self._extract_mjd_values_recursive(tim_file_path) + toa_count = len(mjd_values) + + if toa_count == 0: + timespan = 0.0 + else: + # Calculate timespan as max - min (no need to sort) + timespan = float(max(mjd_values) - min(mjd_values)) + + # Cache only the results we need (timespan and count) + self._file_cache[tim_file_path] = (timespan, toa_count) + + if toa_count > 0: + self.logger.debug(f"Cached data for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs") + else: + self.logger.debug(f"Cached data for {tim_file_path}: No TOAs found") + return timespan, toa_count + + except Exception as e: + self.logger.warning(f"Parsing failed for {tim_file_path}: {e}") + self.logger.debug("File may contain non-standard TIM format or malformed data") + # Cache empty result to avoid repeated failures + empty_result = (0.0, 0) + self._file_cache[tim_file_path] = empty_result + return empty_result + + def calculate_timespan(self, tim_file_path: Path) -> float: + """Calculate timespan from TIM file using independent parsing logic. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Timespan in days (max(mjd) - min(mjd)) + """ + timespan, toa_count = self._get_timespan_and_count(tim_file_path) + + if toa_count == 0: + self.logger.warning(f"No TOAs found in {tim_file_path}") + return 0.0 + + self.logger.debug(f"Timespan for {tim_file_path}: {timespan:.1f} days ({toa_count} TOAs)") + return timespan + + def count_toas(self, tim_file_path: Path) -> int: + """Count the number of TOAs in a TIM file. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Number of TOAs found in the file + """ + _, toa_count = self._get_timespan_and_count(tim_file_path) + + self.logger.debug(f"TOA count for {tim_file_path}: {toa_count} TOAs") + return toa_count + + def clear_cache(self) -> None: + """Clear the file cache.""" + self._file_cache.clear() + self.logger.debug("File cache cleared") + + def get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: + """Get both timespan and TOA count efficiently using cached data. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Tuple of (timespan_in_days, toa_count) + """ + timespan, toa_count = self._get_timespan_and_count(tim_file_path) + + if toa_count == 0: + self.logger.warning(f"No TOAs found in {tim_file_path}") + return 0.0, 0 + + self.logger.debug(f"Timespan and count for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs") + return timespan, toa_count + + def _extract_mjd_values_recursive(self, tim_file_path: Path) -> List[float]: + """Recursively extract MJD values from TIM file and included files. 
+ + Args: + tim_file_path: Path to the TIM file + + Returns: + List of MJD values from all TOA lines + """ + mjd_values = [] + + # Avoid infinite recursion + if tim_file_path in self._processed_files: + self.logger.warning(f"Circular INCLUDE detected: {tim_file_path}") + return mjd_values + + self._processed_files.add(tim_file_path) + + try: + with open(tim_file_path, "r") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Use independent parsing for both TOA lines and commands + try: + mjd_tuple, d = _parse_TOA_line(line) + except Exception as e: + # Parsing may fail on malformed lines - skip them gracefully + self.logger.debug(f"Skipping malformed line in {tim_file_path}: {line.strip()} - {e}") + continue + + # Handle commands (especially INCLUDE) + if d["format"] == "Command": + self._handle_command(d, tim_file_path, mjd_values) + continue + + # Skip non-TOA lines + if d["format"] in ("Comment", "Blank", "Unknown"): + continue + + # Extract MJD from TOA line + if mjd_tuple is not None: + # Convert (int, float) tuple to float MJD + mjd_value = float(mjd_tuple[0]) + float(mjd_tuple[1]) + mjd_values.append(mjd_value) + + except Exception as e: + self.logger.error(f"Error reading TIM file {tim_file_path}: {e}") + + return mjd_values + + def _handle_command(self, d: dict, current_file: Path, mjd_values: List[float]) -> None: + """Handle TIM file commands using parsed command data. + + Args: + d: Parsed command dictionary + current_file: Current TIM file being processed + mjd_values: List to extend with MJD values from included files + """ + if d["format"] != "Command": + return + + cmd = d["Command"][0].upper() + + if cmd == "INCLUDE": + if len(d["Command"]) < 2: + self.logger.warning(f"INCLUDE command without filename: {d['Command']}") + return + + include_file = d["Command"][1] + include_path = current_file.parent / include_file + + if include_path.exists(): + self.logger.debug(f"Processing included TOA file {include_path}") + included_mjds = self._extract_mjd_values_recursive(include_path) + mjd_values.extend(included_mjds) + else: + self.logger.warning(f"INCLUDE file not found: {include_path}") + # Other commands don't affect timespan calculation diff --git a/pyproject.toml b/pyproject.toml index 091ce44..73822d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,7 @@ description = "A Python wrapper for tempo2" authors = [{name = "Michele Vallisneri", email = "vallis@vallis.org"}] urls = { Homepage = "https://github.com/vallis/libstempo" } readme = "README.md" -license = "MIT" -license-files = [ "LICENSE" ] +license = {file = "LICENSE"} classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Science/Research", diff --git a/tests/test_imports.py b/tests/test_imports.py index 90f76ca..ba964c6 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -6,3 +6,5 @@ def test_imports(): import libstempo.toasim # noqa:F401 import libstempo.eccUtils # noqa: F401 import libstempo.spharmORFbasis # noqa: F401 + import libstempo.sandbox # noqa: F401 + import libstempo.tim_file_analyzer # noqa: F401 diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py new file mode 100644 index 0000000..9c345ac --- /dev/null +++ b/tests/test_sandbox.py @@ -0,0 +1,475 @@ +import unittest +import libstempo as t2 +from libstempo.sandbox import tempopulsar, Policy, configure_logging +from libstempo.tim_file_analyzer import TimFileAnalyzer +import tempfile +import numpy as np +from numpy.testing import 
assert_allclose + + +class TestSandbox(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.parfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.par" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_basic_sandbox_usage(self): + """Test basic sandbox functionality""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + self.assertEqual(psr.name, "1909-3744") + self.assertEqual(psr.nobs, 1001) + + def test_policy_configuration(self): + """Test Policy configuration""" + policy = Policy(ctor_retry=2, call_timeout_s=30.0) + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, policy=policy) + self.assertEqual(psr.name, "1909-3744") + + def test_designmatrix_call(self): + """Test calling designmatrix through sandbox""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + dmat = psr.designmatrix() + self.assertEqual(dmat.shape, (1001, 83)) + + def test_attribute_access(self): + """Test accessing attributes through sandbox""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + self.assertEqual(psr.name, "1909-3744") + self.assertEqual(psr.nobs, 1001) + self.assertEqual(len(psr.stoas), 1001) + + def test_logging_configuration(self): + """Test logging configuration""" + # This should not raise an exception + configure_logging(level="DEBUG", enable_console=False) + configure_logging(level="INFO", enable_console=True) + + def test_logs_readout(self): + """Test that logs() captures tempo2 output via savepar().""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + # Baseline logs + logs_before = psr.logs(2000) + self.assertIsInstance(logs_before, str) + # Invoke an operation that emits tempo2 text output + tmp_par = tempfile.NamedTemporaryFile(delete=True) + tmp_par.close() + _ = psr.savepar(tmp_par.name) + # Give background drain thread a moment to process + import time as _t + + _t.sleep(0.1) + logs_after = psr.logs(8000) + self.assertIsInstance(logs_after, str) + # Expect noticeable output; check growth and presence of a known token + self.assertGreater(len(logs_after), len(logs_before)) + self.assertIn("Results for PSR", logs_after) + + def test_sandbox_native_parity(self): + """Compare key attributes and outputs between sandbox and native tempopulsar.""" + psr_s = tempopulsar(parfile=self.parfile, timfile=self.timfile) + psr_n = t2.tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Scalar/string attributes + self.assertEqual(psr_s.name, psr_n.name) + self.assertEqual(psr_s.nobs, psr_n.nobs) + + # Arrays: use a tight tolerance + assert_allclose(psr_s.stoas, psr_n.stoas, rtol=0, atol=0) + + # Residuals + res_s = psr_s.residuals() + res_n = psr_n.residuals() + self.assertEqual(res_s.shape, res_n.shape) + assert_allclose(res_s, res_n, rtol=0, atol=0) + + # Design matrix + dm_s = psr_s.designmatrix() + dm_n = psr_n.designmatrix() + self.assertEqual(dm_s.shape, dm_n.shape) + assert_allclose(dm_s, dm_n, rtol=0, atol=0) + + # TOAs + toas_s = psr_s.toas() + toas_n = psr_n.toas() + self.assertEqual(len(toas_s), len(toas_n)) + assert_allclose(np.asarray(toas_s), np.asarray(toas_n), rtol=0, atol=0) + + # Timing model parameters (subset): ensure values and metadata match + s_all = set(psr_s.pars(which="all")) + n_all = set(psr_n.pars(which="all")) + s_fit = set(psr_s.pars()) # defaults to fitted + n_fit = set(psr_n.pars()) + s_set = set(psr_s.pars(which="set")) + n_set = set(psr_n.pars(which="set")) + + for par_name in ["RAJ", "DECJ"]: + if 
par_name in s_all and par_name in n_all: + # numeric values and errors via bulk accessors + s_val = np.asarray(psr_s.vals(which=[par_name]))[0] + n_val = np.asarray(psr_n.vals(which=[par_name]))[0] + s_err = np.asarray(psr_s.errs(which=[par_name]))[0] + n_err = np.asarray(psr_n.errs(which=[par_name]))[0] + assert_allclose(s_val, n_val, rtol=0, atol=0) + assert_allclose(s_err, n_err, rtol=0, atol=0) + # fit/set flags via pars() groups + self.assertEqual(par_name in s_fit, par_name in n_fit) + self.assertEqual(par_name in s_set, par_name in n_set) + + def test_param_proxy_accessors(self): + """Test psr['parname'].val/err/fit/set mapping accessors and roundtrips.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Choose parameters that are present and safe to touch + par_val = "RAJ" + par_fit = "DM" # commonly present and safe to toggle fit flag + + # Read val/err via mapping + v0 = psr[par_val].val + e0 = psr[par_val].err + self.assertIsInstance(float(v0), float) + self.assertIsInstance(float(e0), float) + + # Roundtrip val by setting the same value (as Python float) + psr[par_val].val = float(v0) + self.assertAlmostEqual(float(psr[par_val].val), float(v0), places=12) + + # Roundtrip err by setting the same value (as Python float) + psr[par_val].err = float(e0) + self.assertAlmostEqual(float(psr[par_val].err), float(e0), places=12) + + # Toggle fit flag and revert + fit0 = bool(psr[par_fit].fit) + psr[par_fit].fit = not fit0 + self.assertEqual(bool(psr[par_fit].fit), (not fit0)) + # revert + psr[par_fit].fit = fit0 + self.assertEqual(bool(psr[par_fit].fit), fit0) + + # 'set' flag should be boolean and readable; do not change it here + self.assertIsInstance(bool(psr[par_val].set), bool) + + def test_stoas_edit_and_fit_matches_native(self): + """Edit stoas and toaerrs, run fit, and compare residuals to native.""" + rng = np.random.default_rng(12345) + + # Sandbox and native + psr_s = tempopulsar(parfile=self.parfile, timfile=self.timfile) + psr_n = t2.tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Create identical noise realization + noise = 0.1e-6 * rng.standard_normal(psr_s.nobs) / 86400.0 + + # Apply to stoas and toaerrs in both + # Sandbox: use write-through proxies (backward compatible API) + psr_s.stoas[:] = psr_s.stoas[:] + noise + psr_s.toaerrs[:] = 0.1 + + # Native + psr_n.stoas[:] = psr_n.stoas + noise + psr_n.toaerrs[:] = 0.1 + + # Fit both + _ = psr_s.fit() + _ = psr_n.fit() + + # Compare residuals tightly + res_s = psr_s.residuals() + res_n = psr_n.residuals() + self.assertEqual(res_s.shape, res_n.shape) + assert_allclose(res_s, res_n, rtol=0, atol=0) + + def test_param_attribute_proxy_no_pickling_and_roundtrip(self): + """Access psr.RAJ attribute (check pickling error) and roundtrip fields.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + + # This attribute access used to force pickling of a non-picklable param object. 
+ p = psr.RAJ + + # Read primitives + v0 = float(p.val) + e0 = float(p.err) + f0 = bool(p.fit) + + # Sanity on types + self.assertIsInstance(v0, float) + self.assertIsInstance(e0, float) + self.assertIsInstance(f0, bool) + + # Roundtrip same values (ensures proxy->worker set path works) + p.val = v0 + self.assertAlmostEqual(float(psr.RAJ.val), v0, places=12) + + p.err = e0 + self.assertAlmostEqual(float(psr.RAJ.err), e0, places=12) + + # Toggle fit and revert + p.fit = not f0 + self.assertEqual(bool(psr.RAJ.fit), (not f0)) + p.fit = f0 + self.assertEqual(bool(psr.RAJ.fit), f0) + + # Proxy should be printable + _ = repr(p) + _ = str(p) + + def test_param_attribute_vs_mapping_consistency(self): + """Ensure attribute-style psr.RAJ and mapping psr['RAJ'] remain consistent.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + + # Values and errors agree between attribute and mapping APIs + self.assertAlmostEqual(float(psr.RAJ.val), float(psr["RAJ"].val), places=12) + self.assertAlmostEqual(float(psr.RAJ.err), float(psr["RAJ"].err), places=12) + + # Setting via attribute reflects in mapping + new_val = float(psr.RAJ.val) + psr.RAJ.val = new_val + self.assertAlmostEqual(float(psr["RAJ"].val), new_val, places=12) + + # Setting via mapping reflects in attribute + new_err = float(psr["RAJ"].err) + psr["RAJ"].err = new_err + self.assertAlmostEqual(float(psr.RAJ.err), new_err, places=12) + + +class TestStateManagement(unittest.TestCase): + """Tests for state management and crash recovery.""" + + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.parfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.par" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_param_state_capture(self): + """Test that parameter modifications are captured in state cache.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify parameters + original_raj = psr["RAJ"].val + psr["RAJ"].val = original_raj + 0.001 + psr["DM"].fit = False + + # Check state cache + self.assertIn("RAJ", psr._state.param_cache) + self.assertIn("val", psr._state.param_cache["RAJ"]) + self.assertEqual(psr._state.param_cache["RAJ"]["val"], original_raj + 0.001) + self.assertEqual(psr._state.param_cache["DM"]["fit"], False) + + def test_array_state_capture(self): + """Test that array modifications are captured in state cache.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify arrays + original_stoas = psr.stoas.copy() + psr.stoas[:] = original_stoas + 1e-6 + + # Check state cache + self.assertIn("stoas", psr._state.array_cache) + expected_stoas = original_stoas + 1e-6 + np.testing.assert_allclose(psr._state.array_cache["stoas"], expected_stoas) + + def test_state_restoration_after_recycle(self): + """Test that state is restored after worker recycle.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify state + psr["RAJ"].val = 5.020 + psr["DM"].fit = True + original_stoas = psr.stoas.copy() + psr.stoas[:] = original_stoas + 2e-6 + + # Force recycle + psr._recycle() + + # Verify restoration + self.assertAlmostEqual(psr["RAJ"].val, 5.020, places=6) + self.assertTrue(psr["DM"].fit) + np.testing.assert_allclose(psr.stoas, original_stoas + 2e-6, rtol=1e-10) + + def test_crash_statistics_tracking(self): + """Test crash statistics tracking.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Initial stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 
0) + self.assertIsNone(stats["last_crash_at"]) + + # Record crash manually + psr._record_crash() + + # Check stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + + # Record another crash + psr._record_crash() + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 2) + + def test_crash_stats_after_recycle(self): + """Test crash statistics after recycle.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Record crash and recycle + psr._record_crash() + psr._recycle() + + # Check stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + self.assertGreater(stats["worker_age_s"], 0) + + def test_complete_crash_recovery_workflow(self): + """Test complete workflow: modify -> crash -> recover -> verify.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Step 1: Modify state + original_raj = psr["RAJ"].val + original_dm_fit = psr["DM"].fit + original_stoas = psr.stoas.copy() + + psr["RAJ"].val = original_raj + 0.005 + psr["DM"].fit = not original_dm_fit + psr.stoas[:] = original_stoas + 5e-6 + + # Step 2: Simulate crash and recovery + psr._record_crash() + psr._recycle() + + # Step 3: Verify state restoration + self.assertAlmostEqual(psr["RAJ"].val, original_raj + 0.005, places=6) + self.assertEqual(psr["DM"].fit, not original_dm_fit) + np.testing.assert_allclose(psr.stoas, original_stoas + 5e-6, rtol=1e-10) + + # Step 4: Verify crash stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + + def test_state_preservation_across_multiple_crashes(self): + """Test state preservation across multiple crashes.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Initial modifications + psr["RAJ"].val = 5.025 + psr["DM"].fit = False + + # Multiple crashes and recoveries + for i in range(3): + psr._record_crash() + psr._recycle() + + # Verify state preserved + self.assertAlmostEqual(psr["RAJ"].val, 5.025, places=6) + self.assertFalse(psr["DM"].fit) + + # Check final crash count + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 3) + + def test_state_restoration_with_invalid_parameters(self): + """Test state restoration when some parameters are invalid.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify valid parameter + psr["RAJ"].val = 5.030 + + # Add invalid parameter to cache (simulate edge case) + psr._state.param_cache["INVALID_PARAM"] = {"val": 999.0} + + # Recycle and verify valid parameter restored + psr._recycle() + self.assertAlmostEqual(psr["RAJ"].val, 5.030, places=6) + + def test_empty_state_cache_restoration(self): + """Test restoration when state cache is empty.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Ensure empty cache + psr._state.param_cache.clear() + psr._state.array_cache.clear() + + # Recycle should not fail + psr._recycle() + + # Basic functionality should still work + self.assertEqual(psr.name, "1909-3744") + + def test_state_capture_performance(self): + """Test that state capture doesn't significantly impact performance.""" + import time + + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Time parameter modifications + start = time.time() + for i in range(100): + psr["RAJ"].val = 5.0 + i * 0.001 + param_time = time.time() - start + + # Time array modifications + start = 
time.time() + for i in range(10): + # Get the array, modify it, and set it back + current_stoas = psr.stoas.copy() + psr.stoas[:] = current_stoas + i * 1e-6 + array_time = time.time() - start + + # Should be reasonable (adjust thresholds as needed) + self.assertLess(param_time, 5.0) # 100 param changes in < 5 seconds + self.assertLess(array_time, 10.0) # 10 array changes in < 10 seconds + + +class TestTimFileAnalyzer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_toa_count(self): + """Test TOA counting functionality""" + analyzer = TimFileAnalyzer() + count = analyzer.count_toas(self.timfile) + self.assertEqual(count, 1001) + + def test_timespan_calculation(self): + """Test timespan calculation""" + analyzer = TimFileAnalyzer() + timespan = analyzer.calculate_timespan(self.timfile) + self.assertGreater(timespan, 0) + self.assertIsInstance(timespan, float) + + def test_combined_analysis(self): + """Test getting both timespan and count""" + analyzer = TimFileAnalyzer() + timespan, count = analyzer.get_timespan_and_count(self.timfile) + self.assertEqual(count, 1001) + self.assertGreater(timespan, 0) + self.assertIsInstance(timespan, float) + + def test_cache_functionality(self): + """Test that caching works correctly""" + analyzer = TimFileAnalyzer() + + # First call + timespan1, count1 = analyzer.get_timespan_and_count(self.timfile) + + # Second call should use cache + timespan2, count2 = analyzer.get_timespan_and_count(self.timfile) + + self.assertEqual(timespan1, timespan2) + self.assertEqual(count1, count2) + + # Clear cache and test again + analyzer.clear_cache() + timespan3, count3 = analyzer.get_timespan_and_count(self.timfile) + self.assertEqual(timespan1, timespan3) + self.assertEqual(count1, count3) + + +if __name__ == "__main__": + unittest.main()
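
The new `TimFileAnalyzer` module is not demonstrated in the usage notes, so a minimal sketch follows. The TOA values, file name, and observatory code are invented for illustration; only the public methods defined above (`get_timespan_and_count`, `count_toas`, `calculate_timespan`) are used.

```python
from pathlib import Path
import tempfile

from libstempo.tim_file_analyzer import TimFileAnalyzer

# Three Tempo2-format TOA lines (name freq mjd error obs); values are made up.
tim_text = "\n".join(
    [
        "FORMAT 1",
        "fake.ff 1440.0 55000.0000000 1.5 meerkat",
        "fake.ff 1440.0 55365.5000000 1.2 meerkat",
        "fake.ff 1440.0 55730.2500000 1.8 meerkat",
    ]
)

with tempfile.TemporaryDirectory() as tmpdir:
    tim_path = Path(tmpdir) / "fake.tim"
    tim_path.write_text(tim_text + "\n")

    analyzer = TimFileAnalyzer()
    timespan, ntoa = analyzer.get_timespan_and_count(tim_path)
    print(ntoa)      # 3
    print(timespan)  # 730.25 (max MJD - min MJD, in days)
```

Results are cached per file, so a subsequent `analyzer.count_toas(tim_path)` reuses the parsed values instead of re-reading the file.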
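
The tests above also exercise two behaviours that the usage notes do not spell out: edits made through the parameter and array proxies are cached so they can be replayed into a fresh worker, and each proxy keeps crash statistics. A rough sketch, assuming placeholder par/tim file names and the `get_crash_stats()` accessor exercised in `TestStateManagement`:

```python
from libstempo.sandbox import tempopulsar, Policy

# Recycle the worker aggressively; proxied edits are captured and restored
# automatically when a replacement worker process is spawned.
policy = Policy(max_calls_per_worker=200, rss_soft_limit_mb=1024)
psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy)  # placeholder files

psr["DM"].fit = True   # goes through _ParamProxy and is recorded in the state cache
psr.toaerrs[:] = 0.1   # goes through _ArrayProxy and is recorded as well

for _ in range(10):
    psr.fit()
    psr.residuals()

stats = psr.get_crash_stats()
print(stats["crash_count"], stats["last_crash_at"], stats["worker_age_s"])
```

If a segfault does occur, the cached parameter and array edits are applied to the replacement worker, which is what `test_state_restoration_after_recycle` and `test_complete_crash_recovery_workflow` verify.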