From 498b1366559e5d1e83f2b9c09e68031830d60e7d Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Fri, 10 Oct 2025 17:20:16 +0000 Subject: [PATCH 01/16] feat: Added sandbox libstempo interface --- MANIFEST.in | 2 + README.md | 99 +++ libstempo/__init__.py | 3 + libstempo/sandbox.py | 1319 ++++++++++++++++++++++++++++++++ libstempo/tim_file_analyzer.py | 548 +++++++++++++ tests/test_imports.py | 2 + tests/test_sandbox.py | 96 +++ 7 files changed, 2069 insertions(+) create mode 100644 libstempo/sandbox.py create mode 100644 libstempo/tim_file_analyzer.py create mode 100644 tests/test_sandbox.py diff --git a/MANIFEST.in b/MANIFEST.in index ed72932..23068a1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -16,6 +16,8 @@ include libstempo/multinest.py include libstempo/plot.py include libstempo/spharmORFbasis.py include libstempo/toasim.py +include libstempo/sandbox.py +include libstempo/tim_file_analyzer.py include libstempo/ecc_vs_nharm.txt include demo/libstempo-demo.ipynb include demo/libstempo-toasim-demo.ipynb diff --git a/README.md b/README.md index 5e44f70..b287af0 100644 --- a/README.md +++ b/README.md @@ -64,3 +64,102 @@ pip install libstempo ## Usage See [Demo Notebook 1](https://github.com/vallis/libstempo/blob/master/demo/libstempo-demo.ipynb) for basic usage and [Demo Notebook 2](https://github.com/vallis/libstempo/blob/master/demo/libstempo-toasim-demo.ipynb) for simulation usage. + +## Sandbox Mode (Crash-Protected) + +libstempo includes a sandbox mode that provides crash isolation and automatic retry capabilities. This is particularly useful when working with problematic pulsars or long-running analyses where tempo2 crashes are common. + +### Basic Usage + +The sandbox provides a drop-in replacement for the standard `tempopulsar` class: + +```python +from libstempo.sandbox import tempopulsar + +# Basic usage - same API as regular tempopulsar +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False) +residuals = psr.residuals() +design_matrix = psr.designmatrix() +``` + +### Advanced Configuration + +```python +from libstempo.sandbox import tempopulsar, Policy, configure_logging + +# Configure logging for debugging +configure_logging(level="DEBUG", log_file="tempo2.log") + +# Configure retry and timeout policies +policy = Policy( + ctor_retry=5, # Retry constructor 5 times on failure + call_timeout_s=300.0, # 5-minute timeout per RPC call + max_calls_per_worker=1000, # Recycle worker after 1000 calls + max_age_s=3600, # Recycle worker after 1 hour + rss_soft_limit_mb=2048 # Recycle worker if memory exceeds 2GB +) + +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy) +``` + +### Environment Support + +The sandbox supports different Python environments: + +```python +# Use virtual/conda environment +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="tempo2_intel") + +# Use system Python with Rosetta (macOS) +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="arch") + +# Use explicit Python path +psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="python:/path/to/python") +``` + +### Key Benefits + +- **Crash Isolation**: Segfaults in tempo2 only kill the worker process, not your main kernel +- **Automatic Retry**: Built-in retry logic for transient failures +- **Worker Recycling**: Prevents memory leaks and resource accumulation +- **Environment Flexibility**: Support for conda, venv, and Rosetta environments +- **Enhanced Logging**: Comprehensive logging for debugging and monitoring 
- **Proactive TOA Handling**: Automatically handles large TOA files to prevent "Too many TOAs" errors
+
+### Performance
+
+The sandbox adds roughly 9x overhead at initialization but only about 1.2x for computational operations such as `residuals()` and `designmatrix()`. For heavy computations the overhead becomes negligible relative to the actual work. Use the sandbox when stability is critical, and direct libstempo when raw performance is paramount.
+
+### Bulk Loading
+
+For processing many pulsars:
+
+```python
+from libstempo.sandbox import load_many, Policy
+
+pairs = [("J1713.par", "J1713.tim"), ("J1909.par", "J1909.tim"), ...]
+policy = Policy(ctor_retry=3, call_timeout_s=120.0)
+
+ok_by_name, retried_by_name, failed_list = load_many(pairs, policy=policy, parallel=8)
+
+print(f"Successfully loaded: {len(ok_by_name)}")
+print(f"Required retries: {len(retried_by_name)}")
+print(f"Failed: {len(failed_list)}")
+```
+
+### Error Handling
+
+The sandbox defines specific exception types:
+
+```python
+from libstempo.sandbox import Tempo2Error, Tempo2Crashed, Tempo2Timeout
+
+try:
+    psr = tempopulsar(parfile="problematic.par", timfile="problematic.tim")
+except Tempo2Crashed:
+    print("Worker process crashed - likely a segfault")
+except Tempo2Timeout:
+    print("Worker timed out")
+except Tempo2Error as e:
+    print(f"Sandbox error: {e}")
+```
diff --git a/libstempo/__init__.py b/libstempo/__init__.py
index e672189..8969fae 100644
--- a/libstempo/__init__.py
+++ b/libstempo/__init__.py
@@ -1,6 +1,9 @@
 import os
 
 from ._find_tempo2 import find_tempo2_runtime
+# Import sandbox functionality
+from .sandbox import tempopulsar as sandbox_tempopulsar, Policy, configure_logging
+from .tim_file_analyzer import TimFileAnalyzer
 
 # check to see if TEMPO2 environment variable is set
 TEMPO2_RUNTIME = os.getenv("TEMPO2")
diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py
new file mode 100644
index 0000000..e7c53b1
--- /dev/null
+++ b/libstempo/sandbox.py
@@ -0,0 +1,1319 @@
+# sandbox.py
+"""
+Process sandbox for libstempo/tempo2 that keeps each pulsar in its own clean
+subprocess. A segfault in tempo2/libstempo only kills the worker, not your kernel.
+
+Usage (drop-in):
+    from libstempo.sandbox import tempopulsar
+    psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False)
+    r = psr.residuals()
+
+Advanced with logging:
+    from libstempo.sandbox import tempopulsar, configure_logging, Policy
+    configure_logging(level="DEBUG", log_file="tempo2.log")
+    policy = Policy(ctor_retry=5, call_timeout_s=300.0)
+    psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy)
+
+With specific environment:
+    psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", env_name="myenv")
+    # or for conda: env_name="mycondaenv"
+    # or explicit path: env_name="python:/path/to/python"
+
+With persistent workers (no recycling/timeouts):
+    policy = Policy(
+        call_timeout_s=None,         # No RPC timeouts
+        max_calls_per_worker=None,   # Never recycle by call count
+        max_age_s=None,              # Never recycle by age
+        rss_soft_limit_mb=None       # Never recycle by memory
+    )
+    psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy)
+
+Advanced:
+    from libstempo.sandbox import load_many, Policy
+    ok, retried, failed = load_many([("J1713.par","J1713.tim"), ...], policy=Policy())
+
+Environment selection (Apple Silicon + Rosetta etc.):
+    psr = tempopulsar(..., env_name="tempo2_intel")        # conda env
+    psr = tempopulsar(..., env_name="myvenv")              # venv (~/.venvs/myvenv, etc.)
+    psr = tempopulsar(..., env_name="arch")                # system python via Rosetta (arch -x86_64)
+    psr = tempopulsar(..., env_name="python:/abs/python")  # explicit Python path
+
+You can force Rosetta prefix via env var:
+    TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64"
+
+Logging:
+    The sandbox includes comprehensive standard-library logging for debugging
+    and monitoring. Use configure_logging() to set up logging levels and
+    outputs. Logs include:
+    - Worker process lifecycle (creation, recycling, termination)
+    - RPC call details and timing
+    - Constructor retry attempts and failures
+    - Memory usage and recycling decisions
+    - Error details and recovery attempts
+
+Robustness:
+    The sandbox suppresses libstempo debug output during construction
+    to prevent interference with the JSON-RPC protocol. This ensures reliable
+    communication even when libstempo prints diagnostic messages. The suppression
+    works at the OS file descriptor level to catch output from C libraries.
+"""
+
+from __future__ import annotations
+
+import base64
+import contextlib
+import dataclasses
+import json
+import os
+import pickle
+import platform
+import select
+import shutil
+import signal
+import subprocess
+import sys
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+# Import TimFileAnalyzer for proactive TOA counting
+from .tim_file_analyzer import TimFileAnalyzer
+
+# Standard logging
+import logging
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------- Public Exceptions ---------------------------- #
+
+
+class Tempo2Error(Exception):
+    """Base class for sandbox errors."""
+
+
+class Tempo2Crashed(Tempo2Error):
+    """The worker process crashed or died unexpectedly (likely a segfault)."""
+
+
+class Tempo2Timeout(Tempo2Error):
+    """The worker did not reply in time; it was terminated."""
+
+
+class Tempo2ProtocolError(Tempo2Error):
+    """Malformed RPC request/response or other IPC failure."""
+
+
+class Tempo2ConstructorFailed(Tempo2Error):
+    """Constructor failed even after retries."""
+
+
+# ------------------------------- Policy knobs ----------------------------- #
+@dataclass(frozen=True)
+class Policy:
+    """Configuration policy for sandbox worker behavior and lifecycle management.
+
+    Controls retry behavior, timeouts, and worker recycling policies.
+
+    Example (illustrative values):
+
+        policy = Policy(ctor_retry=3, call_timeout_s=120.0, rss_soft_limit_mb=2048)
+        psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", policy=policy)
+    """
+
+    # Constructor protection
+    ctor_retry: int = 5  # number of extra tries after the first
+    ctor_backoff: float = 0.75  # seconds between ctor retries
+    preload_residuals: bool = False  # call residuals() once after ctor
+    preload_designmatrix: bool = False  # call designmatrix() once after ctor
+    preload_toas: bool = False  # call toas() once after ctor
+    preload_fit: bool = False  # call fit() once after ctor
+
+    # RPC protection
+    call_timeout_s: Optional[float] = (
+        None  # per-call timeout (seconds), None = no timeout
+    )
+    kill_grace_s: float = 2.0  # after timeout, wait before SIGKILL
+
+    # Recycling / hygiene
+    max_calls_per_worker: Optional[int] = (
+        None  # recycle after this many good calls, None = never recycle by calls
+    )
+    max_age_s: Optional[float] = (
+        None  # recycle after this many seconds, None = never recycle by age
+    )
+    rss_soft_limit_mb: Optional[int] = None  # if provided, recycle when exceeded
+
+    # Proactive TOA handling for large files
+    auto_nobs_retry: bool = True  # automatically pass a maxobs kwarg for large TOA files
+    nobs_threshold: int = 10000  # pass maxobs if the TOA count exceeds this threshold
+    nobs_safety_margin: float = 1.1  # multiplier applied to the TOA count when setting maxobs (1.1 = 10% headroom)
+
+
+# -------------------------- Wire serialization helpers --------------------- #
+
+# We send JSON-RPC 2.0 frames. To avoid JSON-encoding numpy arrays and
+# cross-arch issues, params/result travel as base64-encoded cloudpickle blobs.
+
+try:
+    import cloudpickle as _cp  # best-effort; falls back to pickle if missing
+except Exception:
+    _cp = pickle
+
+
+def _b64_dumps_py(obj: Any) -> str:
+    """Serialize Python object to base64-encoded string using cloudpickle."""
+    return base64.b64encode(_cp.dumps(obj)).decode("ascii")
+
+
+def _b64_loads_py(s: str) -> Any:
+    """Deserialize base64-encoded string to Python object using cloudpickle."""
+    return _cp.loads(base64.b64decode(s.encode("ascii")))
+
+
+def _format_exc_tuple() -> Tuple[str, str, str]:
+    """Format current exception info as tuple of (type_name, message, traceback)."""
+    et, ev, tb = sys.exc_info()
+    name = et.__name__ if et else "Exception"
+    return (name, str(ev), "".join(traceback.format_exception(et, ev, tb)))
+
+
+def _current_rss_mb_portable() -> Optional[int]:
+    """Get current process RSS memory usage in MB, portable across platforms."""
+    try:
+        if sys.platform.startswith("linux"):
+            with open("/proc/self/statm") as f:
+                pages = int(f.read().split()[1])
+            # Convert pages * page-size-in-bytes to MB. (Integer-dividing the
+            # page size itself by 1024*1024 would truncate 4096-byte pages to 0.)
+            return pages * os.sysconf("SC_PAGE_SIZE") // (1024 * 1024)
+    except Exception:
+        pass
+    try:
+        import psutil  # type: ignore
+
+        return int(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
+    except Exception:
+        return None
+
+
+# ----------------------------- Worker (stdio) ------------------------------ #
+
+
+def _worker_stdio_main() -> None:
+    """
+    Runs inside the worker interpreter (possibly Rosetta x86_64).
+    Protocol:
+      1) Immediately print a single 'hello' JSON line with environment info.
+      2) Then serve JSON-RPC 2.0 requests line-by-line on stdin/stdout.
+    Methods: ctor, get, set, call, del, rss, bye
+    Each request's 'params_b64' is a pickled dict of parameters.
+    Each response uses 'result_b64' for Python results, or 'error'.
+
+    Illustrative exchange (blob values abbreviated):
+        -> {"jsonrpc": "2.0", "id": 1, "method": "get", "params_b64": "gASV..."}
+        <- {"jsonrpc": "2.0", "id": 1, "result_b64": "gASV..."}
+ """ + # Step 1: hello handshake + hello = { + "hello": { + "python": sys.version.split()[0], + "executable": sys.executable, + "machine": platform.machine(), + "platform": platform.platform(), + "has_libstempo": False, + "tempo2_version": None, + } + } + try: + try: + from libstempo import tempopulsar as _lib_tempopulsar # noqa + import numpy as _np # noqa + + hello["hello"]["has_libstempo"] = True + # best-effort tempo2 version probe + try: + from libstempo import tempo2 # type: ignore + + hello["hello"]["tempo2_version"] = getattr( + tempo2, "TEMPO2_VERSION", None + ) + except Exception: + pass + except Exception: + pass + finally: + sys.stdout.write(json.dumps(hello) + "\n") + sys.stdout.flush() + + # If libstempo failed to import at hello, try once more here to return clean errors + try: + from libstempo import tempopulsar as _lib_tempopulsar # noqa + import numpy as _np # noqa + except Exception: + # Keep serving, but report on first request + _lib_tempopulsar: Optional[Any] = None + _np: Optional[Any] = None + + obj = None + + def _write_response(resp: Dict[str, Any]) -> None: + """Write JSON response to stdout and flush.""" + sys.stdout.write(json.dumps(resp) + "\n") + sys.stdout.flush() + + # JSON-RPC loop + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + req = json.loads(line) + except Exception: + _write_response( + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "parse error"}, + } + ) + continue + + rid = req.get("id", None) + method = req.get("method", "") + params_b64 = req.get("params_b64", None) + + # Decode params dict if present + params = {} + if params_b64 is not None: + try: + params = _b64_loads_py(params_b64) + if not isinstance(params, dict): + raise TypeError("params_b64 must decode to dict") + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": { + "code": -32602, + "message": f"invalid params: {ev}", + "data": tb, + }, + } + ) + continue + + # Handle methods + try: + if method == "bye": + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py("bye")} + ) + return + + if method == "rss": + rss = _current_rss_mb_portable() + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(rss)} + ) + continue + + if method == "ctor": + if _lib_tempopulsar is None: + raise ImportError("libstempo not available in worker") + + # Suppress stdout/stderr during constructor to prevent libstempo debug output + # from contaminating the JSON-RPC protocol. We need to redirect at the OS level + # because tempo2 writes directly to file descriptors. 
+ import os + + # Save original stdout/stderr file descriptors + original_stdout = os.dup(1) + original_stderr = os.dup(2) + + try: + # Redirect stdout/stderr to /dev/null + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, 1) # stdout + os.dup2(devnull, 2) # stderr + + obj = _lib_tempopulsar(**params["kwargs"]) + if params.get("preload_residuals", True): + _ = obj.residuals(updatebats=True, formresiduals=True) + + finally: + # Restore original stdout/stderr + os.dup2(original_stdout, 1) + os.dup2(original_stderr, 2) + os.close(devnull) + os.close(original_stdout) + os.close(original_stderr) + + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "result_b64": _b64_dumps_py("constructed"), + } + ) + continue + + if obj is None: + raise RuntimeError("object not constructed") + + if method == "get": + name = params["name"] + val = getattr(obj, name) + # copy numpy views to decouple from lib memory + try: + import numpy as _np2 # local alias + + if hasattr(val, "base") and isinstance(val, _np2.ndarray): + val = val.copy() + except Exception: + pass + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(val)} + ) + continue + + if method == "set": + name, value = params["name"], params["value"] + setattr(obj, name, value) + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)} + ) + continue + + if method == "call": + name = params["name"] + args = tuple(params.get("args", ())) + kwargs = dict(params.get("kwargs", {})) + meth = getattr(obj, name) + out = meth(*args, **kwargs) + try: + import numpy as _np2 + + if hasattr(out, "base") and isinstance(out, _np2.ndarray): + out = out.copy() + except Exception: + pass + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)} + ) + continue + + if method == "del": + try: + del obj + except Exception: + pass + obj = None + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)} + ) + continue + + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": {"code": -32601, "message": f"method not found: {method}"}, + } + ) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}, + } + ) + + +# ------------------------------ Subprocess client -------------------------- # + + +class _WorkerProc: + """ + JSON-RPC over stdio subprocess. + Launches the worker in the requested environment (conda/venv/arch/system). 
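+
+    The first line the worker prints is a single JSON "hello" describing its
+    environment, e.g. (values illustrative and abbreviated):
+
+        {"hello": {"python": "3.11.9", "machine": "x86_64", "has_libstempo": true}}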
+ """ + + def __init__(self, policy: Policy, cmd: List[str], require_x86_64: bool = False): + self.policy = policy + self.cmd = cmd + self.proc: Optional[subprocess.Popen] = None + self._id = 0 + logger.info(f"Creating worker process with command: {' '.join(cmd)}") + logger.info(f"Require x86_64 architecture: {require_x86_64}") + self._start(require_x86_64=require_x86_64) + + # ---------- process management ---------- + + def _start(self, require_x86_64: bool = False): + logger.debug("Starting worker subprocess...") + self._hard_kill() # just in case + + # Ensure unbuffered text I/O + env = os.environ.copy() + env.setdefault("PYTHONUNBUFFERED", "1") + + logger.debug( + f"Launching subprocess with environment: PYTHONUNBUFFERED={env.get('PYTHONUNBUFFERED')}" + ) + logger.debug(f"Subprocess working directory: {os.getcwd()}") + self.proc = subprocess.Popen( + self.cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, # line buffered + cwd=os.getcwd(), # Explicitly set working directory + ) + + logger.debug(f"Worker process started with PID: {self.proc.pid}") + + # Hello handshake (one line of JSON) + logger.debug("Waiting for worker hello handshake...") + hello = self._readline_with_timeout(self.policy.call_timeout_s) + if hello is None: + if self.policy.call_timeout_s is None: + logger.error("Worker did not send hello - worker disconnected") + self._hard_kill() + raise Tempo2Crashed("worker did not send hello - worker disconnected") + else: + logger.error("Worker did not send hello in time") + self._hard_kill() + raise Tempo2Timeout("worker did not send hello in time") + + try: + hello_obj = json.loads(hello) + except Exception as e: + logger.error(f"Failed to parse worker hello: {e}") + self._hard_kill() + raise Tempo2ProtocolError(f"malformed hello: {hello!r}") + + info = hello_obj.get("hello", {}) + logger.info(f"Worker hello received: {info}") + + if require_x86_64: + if str(info.get("machine", "")).lower() != "x86_64": + logger.error( + f"Architecture mismatch: worker is {info.get('machine')}, but x86_64 required" + ) + self._hard_kill() + raise Tempo2Error( + f"worker arch is {info.get('machine')}, but x86_64 is required for quad precision" + ) + + if not info.get("has_libstempo", False): + logger.error("libstempo not available in worker environment") + # Keep the worker up; subsequent ctor will return a clean error, + # but we can already warn here to fail fast. + self._hard_kill() + raise Tempo2Error( + "libstempo is not importable inside the selected environment. 
" + f"Worker executable: {info.get('executable')}" + ) + + self.birth = time.time() + self.calls_ok = 0 + logger.info(f"Worker ready and initialized (PID: {self.proc.pid})") + + def _readline_with_timeout(self, timeout: Optional[float]) -> Optional[str]: + if self.proc is None or self.proc.stdout is None: + return None + + if timeout is None: + # No timeout - wait indefinitely + while True: + rlist, _, _ = select.select([self.proc.stdout], [], []) + if rlist: + line = self.proc.stdout.readline() + if not line: # EOF + return None + return line.rstrip("\n") + else: + # With timeout + end = time.time() + timeout + while time.time() < end: + rlist, _, _ = select.select( + [self.proc.stdout], [], [], max(0.01, end - time.time()) + ) + if rlist: + line = self.proc.stdout.readline() + if not line: # EOF + return None + return line.rstrip("\n") + return None + + def _hard_kill(self): + if self.proc and self.proc.poll() is None: + logger.warning(f"Hard killing worker process (PID: {self.proc.pid})") + try: + self.proc.terminate() + except Exception as e: + logger.warning(f"Failed to terminate process: {e}") + pass + t0 = time.time() + while ( + self.proc.poll() is None + and (time.time() - t0) < self.policy.kill_grace_s + ): + time.sleep(0.01) + if self.proc.poll() is None: + logger.warning( + f"Sending SIGKILL to worker process (PID: {self.proc.pid})" + ) + with contextlib.suppress(Exception): + os.kill(self.proc.pid, signal.SIGKILL) + self.proc = None + + def close(self): + logger.debug("Closing worker process...") + if self.proc and self.proc.poll() is None: + try: + logger.debug("Sending bye RPC to worker") + self._send_rpc("bye", {}) + # ignore response; we're closing anyway + except Exception as e: + logger.debug(f"Bye RPC failed (expected): {e}") + pass + self._hard_kill() + logger.debug("Worker process closed") + + def __del__(self): + with contextlib.suppress(Exception): + self.close() + + # ---------- JSON-RPC helpers ---------- + + def _send_rpc( + self, method: str, params: Dict[str, Any], timeout: Optional[float] = None + ) -> Any: + if self.proc is None or self.proc.stdin is None or self.proc.stdout is None: + logger.error("Worker not running, cannot send RPC") + raise Tempo2Crashed("worker not running") + + self._id += 1 + rid = self._id + logger.debug(f"Sending RPC {method} (id: {rid})") + + frame = { + "jsonrpc": "2.0", + "id": rid, + "method": method, + "params_b64": _b64_dumps_py(params), + } + line = json.dumps(frame) + "\n" + + try: + self.proc.stdin.write(line) + self.proc.stdin.flush() + except Exception as e: + logger.error(f"Failed to send RPC {method}: {e}") + self._hard_kill() + raise Tempo2Crashed(f"send failed: {e!r}") + + # Wait for response + t = self.policy.call_timeout_s if timeout is None else timeout + if t is None: + logger.debug(f"Waiting for RPC {method} response (no timeout)") + else: + logger.debug(f"Waiting for RPC {method} response (timeout: {t}s)") + resp_line = self._readline_with_timeout(t) + if resp_line is None: + if t is None: + logger.error(f"RPC {method} failed - worker disconnected") + self._hard_kill() + raise Tempo2Crashed(f"RPC '{method}' failed - worker disconnected") + else: + logger.error(f"RPC {method} timed out after {t}s") + self._hard_kill() + raise Tempo2Timeout(f"RPC '{method}' timed out") + + try: + resp = json.loads(resp_line) + except Exception as e: + logger.error(f"Failed to parse RPC {method} response: {e}") + self._hard_kill() + raise Tempo2ProtocolError(f"malformed response: {resp_line!r}") + + if resp.get("id") != rid: + 
+            logger.error(
+                f"RPC {method} id mismatch: expected {rid}, got {resp.get('id')}"
+            )
+            self._hard_kill()
+            raise Tempo2ProtocolError(
+                f"mismatched id in response: {resp.get('id')} vs {rid}"
+            )
+
+        if "error" in resp and resp["error"] is not None:
+            err = resp["error"]
+            msg = err.get("message", "error")
+            data = err.get("data", "")
+            logger.error(f"RPC {method} failed: {msg}")
+            raise Tempo2Error(f"{msg}\n{data}")
+
+        logger.debug(f"RPC {method} completed successfully")
+        result_b64 = resp.get("result_b64", None)
+        return _b64_loads_py(result_b64) if result_b64 is not None else None
+
+    # Public RPCs
+    def ctor(self, kwargs: Dict[str, Any], preload_residuals: bool):
+        logger.info(f"Constructing tempopulsar with kwargs: {kwargs}")
+        logger.info(f"Preload residuals: {preload_residuals}")
+        return self._send_rpc(
+            "ctor", {"kwargs": kwargs, "preload_residuals": preload_residuals}
+        )
+
+    def get(self, name: str):
+        logger.debug(f"Getting attribute: {name}")
+        return self._send_rpc("get", {"name": name})
+
+    def set(self, name: str, value: Any):
+        logger.debug(f"Setting attribute: {name}")
+        return self._send_rpc("set", {"name": name, "value": value})
+
+    def call(self, name: str, args=(), kwargs=None):
+        logger.debug(f"Calling method: {name} with args={args}, kwargs={kwargs}")
+        return self._send_rpc(
+            "call", {"name": name, "args": tuple(args), "kwargs": dict(kwargs or {})}
+        )
+
+    def rss(self) -> Optional[int]:
+        try:
+            logger.debug("Getting worker RSS memory usage")
+            return self._send_rpc("rss", {})
+        except Exception as e:
+            logger.warning(f"Failed to get RSS: {e}")
+            return None
+
+
+# ------------------------- Command resolution (env_name) -------------------- #
+
+
+def _detect_environment_type(env_name: str) -> str:
+    """
+    Return "conda", "venv", "arch", "python", or "unknown".
+    For the conda family, the tool that resolved the env is appended,
+    e.g. "conda:mamba".
+    """
+    if env_name.startswith("python:"):
+        return "python"
+
+    # conda family
+    for tool in ("conda", "mamba", "micromamba"):
+        try:
+            r = subprocess.run(
+                [tool, "run", "-n", env_name, "python", "--version"],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+            if r.returncode == 0:
+                return "conda:" + tool
+        except Exception:
+            pass
+
+    # common venv locations (same search as _find_venv_python_path below)
+    if _find_venv_python_path(env_name) is not None:
+        return "venv"
+
+    if env_name in ("arch", "rosetta", "system"):
+        return "arch"
+
+    return "unknown"
+
+
+def _find_venv_python_path(env_name: str) -> Optional[str]:
+    """Return the python executable for a venv named env_name, or None."""
+    venv_paths = [
+        Path.home() / ".venvs" / env_name / "bin" / "python",
+        Path.home() / "venvs" / env_name / "bin" / "python",
+        Path.home() / ".virtualenvs" / env_name / "bin" / "python",
+        Path.cwd() / env_name / "bin" / "python",  # also covers env_name == ".venv"
+        # Additional common locations for containers/dev environments
+        Path("/opt/venvs") / env_name / "bin" / "python",
+        Path("/opt/virtualenvs") / env_name / "bin" / "python",
+        Path("/usr/local/venvs") / env_name / "bin" / "python",
+        Path("/home") / "venvs" / env_name / "bin" / "python",
+    ]
+    for p in venv_paths:
+        if p.exists():
+            return str(p)
+    return None
+
+
+def _resolve_worker_cmd(env_name: Optional[str]) -> Tuple[List[str], bool]:
+    """
+    Build the subprocess command to run the worker and whether we require x86_64.
+ Returns (cmd, require_x86_64) + """ + + # Base invocation that runs this file in worker mode: + # Find the src directory dynamically + current_file = Path(__file__).resolve() + src_dir = ( + current_file.parent.parent + ) # Go up from libstempo/sandbox.py to src/ + src_path = str(src_dir) + + def python_to_worker_cmd(python_exe: str) -> List[str]: + """Build command to run worker with given Python executable.""" + return [ + python_exe, + "-c", + f"import sys; sys.path.insert(0, '{src_path}'); import libstempo.sandbox as m; m._worker_stdio_main()", + ] + + arch_prefix_env = os.environ.get("TEMPO2_SANDBOX_WORKER_ARCH_PREFIX", "").strip() + require_x86_64 = False + + # No env_name -> use current python (no Rosetta) + if env_name is None: + py = sys.executable + return (python_to_worker_cmd(py), False) + + # Explicit python path + if env_name.startswith("python:"): + py = env_name.split(":", 1)[1] + return (python_to_worker_cmd(py), False) + + etype = _detect_environment_type(env_name) + + # conda/mamba/micromamba + if etype.startswith("conda:"): + tool = etype.split(":", 1)[1] + cmd = [ + tool, + "run", + "-n", + env_name, + "python", + "-c", + f"import sys; sys.path.insert(0, '{src_path}'); import libstempo.sandbox as m; m._worker_stdio_main()", + ] + # Choosing to require x86_64 only if user *explicitly* asks via arch prefix or env_name == "arch" + require_x86_64 = "arch" in env_name.lower() + if arch_prefix_env: + cmd = arch_prefix_env.split() + cmd + require_x86_64 = True + return (cmd, require_x86_64) + + # venv + if etype == "venv": + py = _find_venv_python_path(env_name) + if not py: + raise Tempo2Error(f"virtualenv '{env_name}' not found in common locations") + cmd = python_to_worker_cmd(py) + if arch_prefix_env: + cmd = arch_prefix_env.split() + cmd + require_x86_64 = True + return (cmd, require_x86_64) + + # system Rosetta + if etype == "arch": + # try system python (python3 or python) + py = shutil.which("python3") or shutil.which("python") + if not py: + raise Tempo2Error("could not find system python for arch mode") + arch = arch_prefix_env.split() if arch_prefix_env else ["arch", "-x86_64"] + require_x86_64 = True + return (arch + python_to_worker_cmd(py), require_x86_64) + + raise Tempo2Error( + f"Environment '{env_name}' not found. " + "Use a conda env name, a venv name, 'arch', or 'python:/abs/python'." + ) + + +# ------------------------------ Public proxy ------------------------------- # + + +@dataclasses.dataclass +class _State: + """Internal state tracking for tempopulsar proxy instances.""" + + created_at: float + calls_ok: int + + +class tempopulsar: + """ + Proxy for libstempo.tempopulsar living inside an isolated subprocess. + + This class provides a drop-in replacement for libstempo.tempopulsar that runs + in a separate process to prevent crashes from affecting the main kernel. + All constructor arguments are forwarded to libstempo.tempopulsar unchanged. + + The proxy automatically handles: + - Worker process lifecycle management + - Automatic retry on failures + - Worker recycling based on age, call count, or memory usage + - JSON-RPC communication over stdio + + Args: + env_name: Environment name (conda env or venv name, 'arch', or 'python:/abs/python'). + If None (default), uses the current Python environment. 
+ policy: Optional Policy instance to configure worker behavior + **kwargs: Additional arguments passed to libstempo.tempopulsar + + Example: + >>> psr = tempopulsar(parfile="J1713.par", timfile="J1713.tim", dofit=False) + >>> residuals = psr.residuals() + >>> design_matrix = psr.designmatrix() + """ + + __slots__ = ( + "_policy", + "_wp", + "_state", + "_ctor_kwargs", + "_env_name", + "_require_x86", + ) + + def __init__(self, env_name: Optional[str] = None, **kwargs): + policy = kwargs.pop("policy", None) + self._policy: Policy = policy if isinstance(policy, Policy) else Policy() + self._env_name = env_name + self._ctor_kwargs = dict(kwargs) + self._wp: Optional[_WorkerProc] = None + self._state = _State(created_at=time.time(), calls_ok=0) + self._require_x86 = False + + logger.info( + f"Creating tempopulsar with env_name='{env_name}', kwargs={self._ctor_kwargs}" + ) + logger.info( + f"Using policy: ctor_retry={self._policy.ctor_retry}, ctor_backoff={self._policy.ctor_backoff}s" + ) + self._construct_with_retries() + + # --------------- construction / reconstruction with retries --------------- # + + def _construct_with_retries(self): + logger.info( + f"Starting construction with {self._policy.ctor_retry + 1} total attempts" + ) + + # Proactive TOA counting to avoid "Too many TOAs" errors + if self._policy.auto_nobs_retry: + self._proactive_nobs_setup() + + last_exc: Optional[Exception] = None + for attempt in range(1 + self._policy.ctor_retry): + logger.info( + f"Construction attempt {attempt + 1}/{self._policy.ctor_retry + 1}" + ) + try: + cmd, require_x86 = _resolve_worker_cmd(self._env_name) + self._require_x86 = require_x86 + logger.debug(f"Resolved worker command: {' '.join(cmd)}") + logger.debug(f"Require x86_64: {require_x86}") + + self._wp = _WorkerProc(self._policy, cmd, require_x86_64=require_x86) + # ctor on the worker (libstempo.tempopulsar) + logger.info("Calling constructor on worker...") + self._wp.ctor( + self._ctor_kwargs, preload_residuals=self._policy.preload_residuals + ) + self._state.created_at = time.time() + self._state.calls_ok = 0 + logger.info(f"Construction successful on attempt {attempt + 1}") + return + except Exception as e: + logger.warning(f"Construction attempt {attempt + 1} failed: {e}") + last_exc = e + # kill and retry + try: + if self._wp: + logger.debug("Cleaning up failed worker") + self._wp.close() + except Exception as cleanup_e: + logger.warning(f"Cleanup failed: {cleanup_e}") + pass + self._wp = None + if attempt < self._policy.ctor_retry: # Don't sleep after last attempt + logger.info(f"Waiting {self._policy.ctor_backoff}s before retry...") + time.sleep(self._policy.ctor_backoff) + logger.error(f"All construction attempts failed. 
Last error: {last_exc}") + raise Tempo2ConstructorFailed( + f"tempopulsar ctor failed after retries: {last_exc}" + ) + + def _proactive_nobs_setup(self): + """Proactively count TOAs and add nobs parameter if needed to avoid 'Too many TOAs' errors.""" + try: + timfile = self._ctor_kwargs.get('timfile') + if not timfile: + logger.debug("No timfile specified, skipping proactive nobs setup") + return + + timfile_path = Path(timfile) + if not timfile_path.exists(): + logger.warning(f"TIM file does not exist: {timfile_path}") + return + + logger.info(f"Proactively counting TOAs in {timfile_path}") + analyzer = TimFileAnalyzer() + toa_count = analyzer.count_toas(timfile_path) + + if toa_count > self._policy.nobs_threshold: + maxobs_with_margin = int(toa_count * self._policy.nobs_safety_margin) + self._ctor_kwargs['maxobs'] = maxobs_with_margin + logger.info(f"Proactively added maxobs={maxobs_with_margin} parameter (TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, margin: {self._policy.nobs_safety_margin})") + else: + logger.debug(f"TOA count {toa_count} below threshold {self._policy.nobs_threshold}, no maxobs parameter needed") + + except Exception as e: + logger.warning(f"Proactive nobs setup failed: {e}") + # Don't raise - this is just optimization, construction should still work + + # ----------------------------- recycling policy --------------------------- # + + def _should_recycle(self) -> bool: + if self._wp is None: + logger.debug("Should recycle: worker is None") + return True + + age = time.time() - self._state.created_at + + # Check age limit (if set) + if self._policy.max_age_s is not None and age > self._policy.max_age_s: + logger.info( + f"Should recycle: worker age {age:.1f}s exceeds max_age_s {self._policy.max_age_s}" + ) + return True + + # Check call limit (if set) + if ( + self._policy.max_calls_per_worker is not None + and self._state.calls_ok >= self._policy.max_calls_per_worker + ): + logger.info( + f"Should recycle: calls_ok {self._state.calls_ok} exceeds max_calls_per_worker {self._policy.max_calls_per_worker}" + ) + return True + + # Check RSS limit (if set) + if self._policy.rss_soft_limit_mb is not None: + rss = self._wp.rss() + if rss and rss > self._policy.rss_soft_limit_mb: + logger.info( + f"Should recycle: RSS {rss}MB exceeds limit {self._policy.rss_soft_limit_mb}MB" + ) + return True + + logger.debug( + f"Worker still healthy: age={age:.1f}s, calls={self._state.calls_ok}" + ) + return False + + def _recycle(self): + logger.info("Recycling worker (creating new one)") + if self._wp is not None: + logger.debug("Closing old worker") + with contextlib.suppress(Exception): + self._wp.close() + self._wp = None + logger.debug("Constructing new worker") + self._construct_with_retries() + + # ---------------------------- RPC convenience ----------------------------- # + + def _rpc(self, call: str, **payload): + if self._wp is None: + logger.debug("Worker is None, constructing...") + self._construct_with_retries() + if self._should_recycle(): + logger.info("Worker needs recycling") + self._recycle() + assert self._wp is not None + try: + if call == "get": + out = self._wp.get(payload["name"]) + elif call == "set": + out = self._wp.set(payload["name"], payload["value"]) + elif call == "call": + out = self._wp.call( + payload["name"], payload.get("args", ()), payload.get("kwargs", {}) + ) + else: + raise Tempo2ProtocolError(f"unknown call {call}") + self._state.calls_ok += 1 + logger.debug(f"RPC {call} successful, total calls: {self._state.calls_ok}") + return out 
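+            # Any sandbox-level failure below triggers exactly one recycle of the
+            # worker, after which the same call is replayed on the fresh process.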
+        except (Tempo2Timeout, Tempo2Crashed, Tempo2ProtocolError, Tempo2Error) as e:
+            logger.warning(f"RPC {call} failed with {type(e).__name__}: {e}")
+            logger.info("Attempting automatic worker recycle and retry")
+            # automatic one-time recycle on a fresh worker
+            self._recycle()
+            assert self._wp is not None
+            if call == "get":
+                out = self._wp.get(payload["name"])
+            elif call == "set":
+                out = self._wp.set(payload["name"], payload["value"])
+            else:
+                out = self._wp.call(
+                    payload["name"], payload.get("args", ()), payload.get("kwargs", {})
+                )
+            self._state.calls_ok += 1
+            logger.info(
+                f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}"
+            )
+            return out
+
+    # ------------------------ Attribute proxying magic ------------------------ #
+
+    def __getattr__(self, name: str):
+        # Filter out IPython-specific attributes to prevent infinite loops
+        if name.startswith('_ipython_') or name in {
+            '_ipython_canary_method_should_not_exist_',
+            '_repr_mimebundle_',
+            '_repr_html_',
+            '_repr_json_',
+            '_repr_latex_',
+            '_repr_png_',
+            '_repr_jpeg_',
+            '_repr_svg_',
+            '_repr_pdf_',
+        }:
+            logger.debug(f"Filtering out IPython attribute: {name}")
+            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+        def _remote_method(*args, **kwargs):
+            return self._rpc("call", name=name, args=args, kwargs=kwargs)
+
+        # Try a GET first; if it errors, assume it's a method
+        try:
+            val = self._rpc("get", name=name)
+        except Tempo2Error:
+            return _remote_method
+        if callable(val):
+            return _remote_method
+        return val
+
+    def __setattr__(self, name: str, value: Any):
+        if name in tempopulsar.__slots__:
+            return object.__setattr__(self, name, value)
+        _ = self._rpc("set", name=name, value=value)
+        return None
+
+    # Explicit helpers for common call shapes
+    def residuals(self, **kwargs):
+        return self._rpc("call", name="residuals", kwargs=kwargs)
+
+    def designmatrix(self, **kwargs):
+        return self._rpc("call", name="designmatrix", kwargs=kwargs)
+
+    def toas(self, **kwargs):
+        return self._rpc("call", name="toas", kwargs=kwargs)
+
+    def fit(self, **kwargs):
+        return self._rpc("call", name="fit", kwargs=kwargs)
+
+    def __del__(self):
+        with contextlib.suppress(Exception):
+            if self._wp is not None:
+                self._wp.close()
+
+
+# -------------------------- Bulk loader (optional) -------------------------- #
+
+
+@dataclass
+class LoadReport:
+    par: str
+    tim: Optional[str]
+    attempts: int
+    ok: bool
+    error: Optional[str] = None
+    retried: bool = False
+
+
+def load_many(
+    pairs: Iterable[Tuple[str, Optional[str]]],
+    policy: Optional[Policy] = None,
+    parallel: int = 8,
+) -> Tuple[Dict[str, tempopulsar], Dict[str, LoadReport], List[LoadReport]]:
+    """
+    Bulk-load many pulsars with bounded parallelism.
+    Returns: (ok_by_name, retried_by_name, failed_list)
+
+      ok_by_name: {psr_name: tempopulsar proxy}
+      retried_by_name: {psr_name: LoadReport} (those that required >=1 retry)
+      failed_list: [LoadReport,...]
+
+    Example (illustrative):
+
+        ok, retried, failed = load_many([("J1713.par", "J1713.tim")], parallel=4)
+        for report in failed:
+            print(report.par, report.error)
+    """
+    pol = policy if isinstance(policy, Policy) else Policy()
+    pairs = list(pairs)  # materialize once so a generator is not exhausted by logging
+    logger.info(
+        f"Starting bulk load of {len(pairs)} pulsars with {parallel} parallel workers"
+    )
+    logger.info(
+        f"Using policy: ctor_retry={pol.ctor_retry}, ctor_backoff={pol.ctor_backoff}s"
+    )
+
+    def _one(par, tim):
+        """Load a single pulsar with retry logic for bulk loading."""
+        logger.debug(f"Loading pulsar: par={par}, tim={tim}")
+        attempts = 0
+        report = LoadReport(par=par, tim=tim, attempts=0, ok=False)
+        last_exc = None
+        for _ in range(1 + pol.ctor_retry):
+            attempts += 1
+            try:
+                psr = tempopulsar(parfile=par, timfile=tim, policy=pol)
+                name = getattr(psr, "name")
+                report.attempts = attempts
+                report.ok = True
+                report.retried = attempts > 1
+                logger.info(f"Successfully loaded {name} in {attempts} attempt(s)")
+                return ("ok", name, psr, report)
+            except Exception as e:
+                logger.warning(f"Failed to load {par} (attempt {attempts}): {e}")
+                last_exc = e
+                time.sleep(pol.ctor_backoff)
+        report.attempts = attempts
+        report.ok = False
+        report.error = f"{last_exc.__class__.__name__}: {last_exc}"
+        logger.error(f"Failed to load {par} after {attempts} attempts: {last_exc}")
+        return ("fail", None, None, report)
+
+    ok: Dict[str, tempopulsar] = {}
+    retried: Dict[str, LoadReport] = {}
+    failed: List[LoadReport] = []
+
+    with ThreadPoolExecutor(max_workers=max(1, parallel)) as ex:
+        futs = {ex.submit(_one, par, tim): (par, tim) for (par, tim) in pairs}
+        for fut in as_completed(futs):
+            kind, name, psr, report = fut.result()
+            if kind == "ok":
+                ok[name] = psr
+                if report.retried:
+                    retried[name] = report
+            else:
+                failed.append(report)
+
+    logger.info(
+        f"Bulk load completed: {len(ok)} successful, {len(retried)} retried, {len(failed)} failed"
+    )
+    return ok, retried, failed
+
+
+# ------------------------------- Quick helpers ------------------------------ #
+
+
+def configure_logging(
+    level: str = "INFO", log_file: Optional[str] = None, enable_console: bool = True
+):
+    """
+    Configure standard logging for the sandbox.
+
+    Args:
+        level: Log level ("DEBUG", "INFO", "WARNING", "ERROR")
+        log_file: Optional file path to log to
+        enable_console: Whether to log to console
+    """
+    # Get the sandbox logger
+    sandbox_logger = logging.getLogger(__name__)
+
+    # Clear existing handlers
+    sandbox_logger.handlers.clear()
+
+    # Set level
+    numeric_level = getattr(logging, level.upper(), logging.INFO)
+    sandbox_logger.setLevel(numeric_level)
+
+    # Create formatter
+    formatter = logging.Formatter(
+        '%(asctime)s | %(levelname)-8s | tempo2_sandbox | %(message)s',
+        datefmt='%H:%M:%S'
+    )
+
+    # Add console handler if requested
+    if enable_console:
+        console_handler = logging.StreamHandler(sys.stderr)
+        console_handler.setFormatter(formatter)
+        sandbox_logger.addHandler(console_handler)
+
+    # Add file handler if requested
+    if log_file:
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setFormatter(formatter)
+        sandbox_logger.addHandler(file_handler)
+
+    logger.info(
+        f"Logging configured: level={level}, console={enable_console}, file={log_file}"
+    )
+
+
+def setup_instructions(env_name: str = "tempo2_intel"):
+    """Print setup instructions for creating a tempo2 environment.
+
+    This is a utility function to help users set up their environment
+    for using the sandbox with different Python environments.
+    """
+    print("Setup instructions for environment '{}':".format(env_name))
+    print("\n1. 
Conda (recommended):") + print(f" conda create -n {env_name} python=3.11") + print(f" conda activate {env_name}") + print(" conda install -c conda-forge tempo2 libstempo") + print(f' # then just: psr = tempopulsar(..., env_name="{env_name}")') + print("\n2. Virtual Environment (Rosetta):") + print(f" arch -x86_64 /usr/local/bin/python3 -m venv ~/.venvs/{env_name}") + print(f" source ~/.venvs/{env_name}/bin/activate") + print(" pip install tempo2 libstempo") + print(f' # then just: psr = tempopulsar(..., env_name="{env_name}")') + print("\n3. System Python with Rosetta:") + print(" # Install Intel Python first (or use system one under arch).") + print( + ' # You can force Rosetta via TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64"' + ) + print(' # then: psr = tempopulsar(..., env_name="arch")') + + +def detect_and_guide(env_name: str): + """Detect environment type and provide guidance for setup. + + This is a utility function to help users understand what type + of environment they have and how to use it with the sandbox. + """ + et = _detect_environment_type(env_name) + print(f"Environment detection for '{env_name}': {et}") + if et.startswith("conda:"): + print("✅ Conda env detected; just use env_name as given.") + elif et == "venv": + p = _find_venv_python_path(env_name) + if p: + print(f"✅ venv detected at {p}") + else: + print("❌ venv name matched, but python path not resolved.") + elif et == "arch": + print("✅ Rosetta/system arch mode will be used.") + elif et == "python": + print("✅ Using explicit Python path.") + else: + print( + "❌ Not found. Use conda env name, venv name, 'arch', or 'python:/abs/python'." + ) + + +# ------------------------------ Module runner ------------------------------- # + +if __name__ == "__main__": + # If executed directly, act as worker (useful for manual debugging): + _worker_stdio_main() diff --git a/libstempo/tim_file_analyzer.py b/libstempo/tim_file_analyzer.py new file mode 100644 index 0000000..d40ae36 --- /dev/null +++ b/libstempo/tim_file_analyzer.py @@ -0,0 +1,548 @@ +"""TimFileAnalyzer - Fast TIM file analyzer without PINT dependencies. + +This module provides a lightweight class to quickly extract TOA MJD values +from TIM files using independent parsing logic that replicates PINT's functionality +without requiring PINT as a dependency. This is useful for environments where +PINT is not available or when you need a minimal implementation. + +The implementation replicates the core functionality from PINT's toa.py module +for parsing TOA lines and handling TIM file commands. 
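+
+A Tempo2-format TOA line looks like this (illustrative values):
+
+    J1909-3744.x 1440.000 55555.1234567890 1.21 meerkat -sys MKBF
+
+and is read as: name, frequency (MHz), MJD, uncertainty (us), observatory
+code, then optional flag/value pairs.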
+ +Author: Rutger van Haasteren -- rutger@vhaasteren.com +Date: 2025-10-10 +""" + +import re +import logging +from pathlib import Path +from typing import List, Set, Dict, Tuple, Optional + +logger = logging.getLogger(__name__) + + +# TOA commands that can appear in TIM files +TOA_COMMANDS = ( + "DITHER", + "EFAC", + "EMAX", + "EMAP", + "EMIN", + "EQUAD", + "FMAX", + "FMIN", + "INCLUDE", + "INFO", + "JUMP", + "MODE", + "NOSKIP", + "PHA1", + "PHA2", + "PHASE", + "SEARCH", + "SIGMA", + "SIM", + "SKIP", + "TIME", + "TRACK", + "ZAWGT", + "FORMAT", + "END", +) + +# Simple observatory name mapping (subset of PINT's observatory registry) +# This covers the most common observatories used in TOA files +OBSERVATORY_NAMES = { + # Common single-letter codes + "A": "Arecibo", + "AO": "Arecibo", + "ARECIBO": "Arecibo", + "B": "GBT", + "GBT": "GBT", + "GREEN_BANK": "GBT", + "C": "CHIME", + "CHIME": "CHIME", + "D": "DSS-43", + "E": "Effelsberg", + "EFFELSBERG": "Effelsberg", + "F": "FAST", + "FAST": "FAST", + "G": "GMRT", + "GMRT": "GMRT", + "H": "Hobart", + "HOBART": "Hobart", + "I": "IAR", + "IAR": "IAR", + "J": "Jodrell", + "JB": "Jodrell", + "JODRELL": "Jodrell", + "K": "Kalyazin", + "KALYAZIN": "Kalyazin", + "L": "Lovell", + "LOVELL": "Lovell", + "M": "MeerKAT", + "MEERKAT": "MeerKAT", + "N": "Nancay", + "NANCAY": "Nancay", + "O": "Ooty", + "OOTY": "Ooty", + "P": "Parkes", + "PARKES": "Parkes", + "Q": "Qitai", + "QITAI": "Qitai", + "R": "RATAN", + "RATAN": "RATAN", + "S": "Sardinia", + "SARDINIA": "Sardinia", + "T": "Tianma", + "TIANMA": "Tianma", + "U": "URUMQI", + "URUMQI": "URUMQI", + "V": "VLA", + "VLA": "VLA", + "W": "Westerbork", + "WESTERBORK": "Westerbork", + "X": "Xinjiang", + "XINJIANG": "Xinjiang", + "Y": "Yunnan", + "YUNNAN": "Yunnan", + "Z": "Zelenchukskaya", + "ZELENCHUKSKAYA": "Zelenchukskaya", + # Special codes + "@": "Barycenter", + "BARYCENTER": "Barycenter", + "SSB": "Barycenter", + # Spacecraft observatories + "FERMI": "Fermi", + "NICER": "NICER", + "SWIFT": "Swift", + "RXTE": "RXTE", + "XTE": "RXTE", +} + + +def _toa_format(line: str, fmt: str = "Unknown") -> str: + """Determine the type of a TOA line. + + Identifies a TOA line as one of the following types: + Comment, Command, Blank, Tempo2, Princeton, ITOA, Parkes, Unknown. + + This replicates PINT's _toa_format function. + """ + # Check for comments first + if ( + line.startswith("C ") and len(line) > 2 and not line[2].isdigit() # C followed by non-digit + or line.startswith("c ") and len(line) > 2 and not line[2].isdigit() # c followed by non-digit + or line.startswith("#") + or line.startswith("CC ") + ): + return "Comment" + + # Check for commands + if line.upper().lstrip().startswith(TOA_COMMANDS): + return "Command" + + # Check for blank lines + if re.match(r"^\s*$", line): + return "Blank" + + # Check for Princeton format: starts with single observatory code followed by space + if re.match(r"[0-9a-zA-Z@] ", line): + return "Princeton" + + # Check for Tempo2 format: long lines, explicitly marked as Tempo2, or structured like Tempo2 + if len(line) > 80 or fmt == "Tempo2": + return "Tempo2" + + # Additional Tempo2 detection: if it looks like a Tempo2 TOA line (has 5+ space-separated fields) + fields = line.split() + if len(fields) >= 5: + # Check if it looks like a Tempo2 TOA line: name freq mjd error obs [flags...] 
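+        # e.g. "J1909-3744 1440.0 55555.123456789 1.21 meerkat" (illustrative)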
+ try: + # Try to parse as Tempo2: name freq mjd error obs + float(fields[1]) # frequency should be numeric + float(fields[2]) # MJD should be numeric + float(fields[3]) # error should be numeric + # If we get here, it looks like Tempo2 format + return "Tempo2" + except (ValueError, IndexError): + pass + + # Check for Parkes format: starts with space, has decimal at position 42 + if re.match(r"^ ", line) and len(line) > 41 and line[41] == ".": + return "Parkes" + + # Check for ITOA format: two non-space chars, decimal at position 15 + if re.match(r"\S\S", line) and len(line) > 14 and line[14] == ".": + return "ITOA" + + # Default to Unknown + return "Unknown" + + +def _get_observatory_name(obs_code: str) -> str: + """Get observatory name from observatory code. + + This is a simplified version of PINT's get_observatory function + that only handles the most common observatory codes. + + Args: + obs_code: Observatory code (e.g., 'A', 'AO', '@') + + Returns: + Observatory name + """ + obs_code_upper = obs_code.upper() + return OBSERVATORY_NAMES.get(obs_code_upper, obs_code_upper) + + +def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int, float]], dict]: + """Parse a one-line ASCII time-of-arrival. + + Return an MJD tuple and a dictionary of other TOA information. + The format can be one of: Comment, Command, Blank, Tempo2, + Princeton, ITOA, Parkes, or Unknown. + + This replicates PINT's _parse_TOA_line function. + """ + MJD = None + fmt = _toa_format(line, fmt) + d = dict(format=fmt) + + if fmt == "Princeton": + # Princeton format + # ---------------- + # columns item + # 1-1 Observatory (one-character code) '@' is barycenter + # 2-2 must be blank + # 16-24 Observing frequency (MHz) + # 25-44 TOA (decimal point must be in column 30 or column 31) + # 45-53 TOA uncertainty (microseconds) + # 69-78 DM correction (pc cm^-3) + try: + # Handle both fixed-width and space-separated Princeton format + if len(line) >= 78: # Fixed-width format + d["obs"] = _get_observatory_name(line[0].upper()) + d["freq"] = float(line[15:24]) + d["error"] = float(line[44:53]) + ii, ff = line[24:44].split(".") + ii = int(ii) + # For very old TOAs, see https://tempo.sourceforge.net/ref_man_sections/toa.txt + if ii < 40000: + ii += 39126 + MJD = (ii, float(f"0.{ff}")) + try: + d["ddm"] = str(float(line[68:78])) + except ValueError: + d["ddm"] = str(0.0) + else: # Space-separated format (fallback) + fields = line.split() + if len(fields) >= 4: + d["obs"] = _get_observatory_name(fields[0].upper()) + d["freq"] = float(fields[1]) + d["error"] = float(fields[3]) + # Parse MJD + if "." in fields[2]: + ii, ff = fields[2].split(".") + ii = int(ii) + if ii < 40000: + ii += 39126 + MJD = (ii, float(f"0.{ff}")) + else: + ii = int(fields[2]) + if ii < 40000: + ii += 39126 + MJD = (ii, 0.0) + d["ddm"] = str(0.0) # Default DM correction + else: + raise ValueError("Not enough fields for Princeton format") + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Princeton format line: {e}") + d["format"] = "Unknown" + + elif fmt == "Tempo2": + # This could use more error catching... + try: + fields = line.split() + d["name"] = fields[0] + d["freq"] = float(fields[1]) + if "." 
in fields[2]: + ii, ff = fields[2].split(".") + MJD = (int(ii), float(f"0.{ff}")) + else: + MJD = (int(fields[2]), 0.0) + d["error"] = float(fields[3]) + d["obs"] = _get_observatory_name(fields[4].upper()) + # All the rest should be flags + flags = fields[5:] + + # Flags and flag-values should be given in pairs. + # The for loop below will fail otherwise. + if len(flags) % 2 != 0: + raise ValueError( + f"Flags and flag-values should be given in pairs. The given flags are {' '.join(flags)}" + ) + + for i in range(0, len(flags), 2): + k, v = flags[i].lstrip("-"), flags[i + 1] + if k in ["error", "freq", "scale", "MJD", "flags", "obs", "name"]: + raise ValueError(f"TOA flag ({k}) will overwrite TOA parameter!") + if not k: + raise ValueError(f"The string {repr(flags[i])} is not a valid flag") + d[k] = v + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Tempo2 format line: {e}") + d["format"] = "Unknown" + + elif fmt == "Command": + d[fmt] = line.split() + elif fmt == "Parkes": + """ + columns item + 1-1 Must be blank + 26-34 Observing Frequency (MHz) + 35-55 TOA (decimal point must be in column 42) + 56-63 Phase offset (fraction of P0, added to TOA) + 64-71 TOA uncertainty + 80-80 Observatory (1 character) + """ + try: + d["name"] = line[1:25] + d["freq"] = float(line[25:34]) + ii = line[34:41] + ff = line[42:55] + MJD = int(ii), float(f"0.{ff}") + phaseoffset = float(line[55:62]) + if phaseoffset != 0: + raise ValueError( + f"Cannot interpret Parkes format with phaseoffset={phaseoffset} yet" + ) + d["error"] = float(line[63:71]) + d["obs"] = _get_observatory_name(line[79].upper()) + except (ValueError, IndexError) as e: + # If parsing fails, treat as unknown format + logger.debug(f"Failed to parse Parkes format line: {e}") + d["format"] = "Unknown" + + elif fmt == "ITOA": + raise RuntimeError(f"TOA format '{fmt}' not implemented yet") + elif fmt not in ["Blank", "Comment"]: + raise RuntimeError( + f"Unable to identify TOA format for line {line!r}, expecting {fmt}" + ) + return MJD, d + + +class TimFileAnalyzer: + """Fast TIM file analyzer for timespan calculation without PINT dependencies. + + This class efficiently extracts TOA MJD values from TIM files using + independent parsing logic that replicates PINT's functionality, + providing both performance and robustness for timespan calculations. + + The analyzer caches results per file to avoid duplicate parsing when both + timespan and TOA count are needed for the same file. + """ + + def __init__(self): + """Initialize the TIM file analyzer.""" + self.logger = logger + self._processed_files: Set[Path] = set() + # Cache for storing timespan and TOA counts per file (not MJD values to save memory) + self._file_cache: Dict[Path, Tuple[float, int]] = {} + + def _get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: + """Get timespan and TOA count from TIM file, using cache if available. 
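+
+        Results are memoized per path in self._file_cache, so repeated calls
+        for the same file (e.g. timespan then count) parse it only once.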
+ + Args: + tim_file_path: Path to the TIM file + + Returns: + Tuple of (timespan_in_days, toa_count) + """ + # Check cache first + if tim_file_path in self._file_cache: + self.logger.debug(f"Using cached data for {tim_file_path}") + return self._file_cache[tim_file_path] + + try: + self._processed_files.clear() + mjd_values = self._extract_mjd_values_recursive(tim_file_path) + toa_count = len(mjd_values) + + if toa_count == 0: + timespan = 0.0 + else: + # Calculate timespan as max - min (no need to sort) + timespan = float(max(mjd_values) - min(mjd_values)) + + # Cache only the results we need (timespan and count) + self._file_cache[tim_file_path] = (timespan, toa_count) + + if toa_count > 0: + self.logger.debug( + f"Cached data for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs" + ) + else: + self.logger.debug(f"Cached data for {tim_file_path}: No TOAs found") + return timespan, toa_count + + except Exception as e: + self.logger.warning(f"Parsing failed for {tim_file_path}: {e}") + self.logger.debug( + "File may contain non-standard TIM format or malformed data" + ) + # Cache empty result to avoid repeated failures + empty_result = (0.0, 0) + self._file_cache[tim_file_path] = empty_result + return empty_result + + def calculate_timespan(self, tim_file_path: Path) -> float: + """Calculate timespan from TIM file using independent parsing logic. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Timespan in days (max(mjd) - min(mjd)) + """ + timespan, toa_count = self._get_timespan_and_count(tim_file_path) + + if toa_count == 0: + self.logger.warning(f"No TOAs found in {tim_file_path}") + return 0.0 + + self.logger.debug( + f"Timespan for {tim_file_path}: {timespan:.1f} days ({toa_count} TOAs)" + ) + return timespan + + def count_toas(self, tim_file_path: Path) -> int: + """Count the number of TOAs in a TIM file. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Number of TOAs found in the file + """ + _, toa_count = self._get_timespan_and_count(tim_file_path) + + self.logger.debug( + f"TOA count for {tim_file_path}: {toa_count} TOAs" + ) + return toa_count + + def clear_cache(self) -> None: + """Clear the file cache.""" + self._file_cache.clear() + self.logger.debug("File cache cleared") + + def get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: + """Get both timespan and TOA count efficiently using cached data. + + Args: + tim_file_path: Path to the TIM file + + Returns: + Tuple of (timespan_in_days, toa_count) + """ + timespan, toa_count = self._get_timespan_and_count(tim_file_path) + + if toa_count == 0: + self.logger.warning(f"No TOAs found in {tim_file_path}") + return 0.0, 0 + + self.logger.debug( + f"Timespan and count for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs" + ) + return timespan, toa_count + + def _extract_mjd_values_recursive(self, tim_file_path: Path) -> List[float]: + """Recursively extract MJD values from TIM file and included files. 
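+
+        INCLUDE commands are followed recursively; files already visited in
+        the current traversal are skipped to guard against circular includes.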
+ + Args: + tim_file_path: Path to the TIM file + + Returns: + List of MJD values from all TOA lines + """ + mjd_values = [] + + # Avoid infinite recursion + if tim_file_path in self._processed_files: + self.logger.warning(f"Circular INCLUDE detected: {tim_file_path}") + return mjd_values + + self._processed_files.add(tim_file_path) + + try: + with open(tim_file_path, "r") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Use independent parsing for both TOA lines and commands + try: + mjd_tuple, d = _parse_TOA_line(line) + except Exception as e: + # Parsing may fail on malformed lines - skip them gracefully + self.logger.debug( + f"Skipping malformed line in {tim_file_path}: {line.strip()} - {e}" + ) + continue + + # Handle commands (especially INCLUDE) + if d["format"] == "Command": + self._handle_command(d, tim_file_path, mjd_values) + continue + + # Skip non-TOA lines + if d["format"] in ("Comment", "Blank", "Unknown"): + continue + + # Extract MJD from TOA line + if mjd_tuple is not None: + # Convert (int, float) tuple to float MJD + mjd_value = float(mjd_tuple[0]) + float(mjd_tuple[1]) + mjd_values.append(mjd_value) + + except Exception as e: + self.logger.error(f"Error reading TIM file {tim_file_path}: {e}") + + return mjd_values + + def _handle_command( + self, d: dict, current_file: Path, mjd_values: List[float] + ) -> None: + """Handle TIM file commands using parsed command data. + + Args: + d: Parsed command dictionary + current_file: Current TIM file being processed + mjd_values: List to extend with MJD values from included files + """ + if d["format"] != "Command": + return + + cmd = d["Command"][0].upper() + + if cmd == "INCLUDE": + if len(d["Command"]) < 2: + self.logger.warning(f"INCLUDE command without filename: {d['Command']}") + return + + include_file = d["Command"][1] + include_path = current_file.parent / include_file + + if include_path.exists(): + self.logger.debug(f"Processing included TOA file {include_path}") + included_mjds = self._extract_mjd_values_recursive(include_path) + mjd_values.extend(included_mjds) + else: + self.logger.warning(f"INCLUDE file not found: {include_path}") + # Other commands don't affect timespan calculation diff --git a/tests/test_imports.py b/tests/test_imports.py index 90f76ca..ba964c6 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -6,3 +6,5 @@ def test_imports(): import libstempo.toasim # noqa:F401 import libstempo.eccUtils # noqa: F401 import libstempo.spharmORFbasis # noqa: F401 + import libstempo.sandbox # noqa: F401 + import libstempo.tim_file_analyzer # noqa: F401 diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py new file mode 100644 index 0000000..3057fc2 --- /dev/null +++ b/tests/test_sandbox.py @@ -0,0 +1,96 @@ +import unittest +import libstempo as t2 +from libstempo.sandbox import tempopulsar, Policy, configure_logging +from libstempo.tim_file_analyzer import TimFileAnalyzer + + +class TestSandbox(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.parfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.par" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_basic_sandbox_usage(self): + """Test basic sandbox functionality""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + self.assertEqual(psr.name, "1909-3744") + self.assertEqual(psr.nobs, 1001) + + def test_policy_configuration(self): + """Test Policy configuration""" + policy 
= Policy(ctor_retry=2, call_timeout_s=30.0) + psr = tempopulsar( + parfile=self.parfile, timfile=self.timfile, policy=policy + ) + self.assertEqual(psr.name, "1909-3744") + + def test_designmatrix_call(self): + """Test calling designmatrix through sandbox""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + dmat = psr.designmatrix() + self.assertEqual(dmat.shape, (1001, 83)) + + def test_attribute_access(self): + """Test accessing attributes through sandbox""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + self.assertEqual(psr.name, "1909-3744") + self.assertEqual(psr.nobs, 1001) + self.assertEqual(len(psr.stoas), 1001) + + def test_logging_configuration(self): + """Test logging configuration""" + # This should not raise an exception + configure_logging(level="DEBUG", enable_console=False) + configure_logging(level="INFO", enable_console=True) + + +class TestTimFileAnalyzer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_toa_count(self): + """Test TOA counting functionality""" + analyzer = TimFileAnalyzer() + count = analyzer.count_toas(self.timfile) + self.assertEqual(count, 1001) + + def test_timespan_calculation(self): + """Test timespan calculation""" + analyzer = TimFileAnalyzer() + timespan = analyzer.calculate_timespan(self.timfile) + self.assertGreater(timespan, 0) + self.assertIsInstance(timespan, float) + + def test_combined_analysis(self): + """Test getting both timespan and count""" + analyzer = TimFileAnalyzer() + timespan, count = analyzer.get_timespan_and_count(self.timfile) + self.assertEqual(count, 1001) + self.assertGreater(timespan, 0) + self.assertIsInstance(timespan, float) + + def test_cache_functionality(self): + """Test that caching works correctly""" + analyzer = TimFileAnalyzer() + + # First call + timespan1, count1 = analyzer.get_timespan_and_count(self.timfile) + + # Second call should use cache + timespan2, count2 = analyzer.get_timespan_and_count(self.timfile) + + self.assertEqual(timespan1, timespan2) + self.assertEqual(count1, count2) + + # Clear cache and test again + analyzer.clear_cache() + timespan3, count3 = analyzer.get_timespan_and_count(self.timfile) + self.assertEqual(timespan1, timespan3) + self.assertEqual(count1, count3) + + +if __name__ == "__main__": + unittest.main() From b3549a3558aaedb60bb9ae182ce229b3494ff595 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Fri, 10 Oct 2025 17:35:43 +0000 Subject: [PATCH 02/16] Fix linting issues in sandbox implementation --- libstempo/sandbox.py | 241 ++++++++++++++---------------------------- tests/test_sandbox.py | 12 +-- 2 files changed, 82 insertions(+), 171 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index e7c53b1..5dfca09 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -126,18 +126,12 @@ class Policy: preload_fit: bool = False # call fit() once after ctor # RPC protection - call_timeout_s: Optional[float] = ( - None # per-call timeout (seconds), None = no timeout - ) + call_timeout_s: Optional[float] = None # per-call timeout (seconds), None = no timeout kill_grace_s: float = 2.0 # after timeout, wait before SIGKILL # Recycling / hygiene - max_calls_per_worker: Optional[int] = ( - None # recycle after this many good calls, None = never recycle by calls - ) - max_age_s: Optional[float] = ( - None # recycle after this many seconds, None = never recycle by age - ) + max_calls_per_worker: 
Optional[int] = None # recycle after this many good calls, None = never recycle by calls + max_age_s: Optional[float] = None # recycle after this many seconds, None = never recycle by age rss_soft_limit_mb: Optional[int] = None # if provided, recycle when beaten # Proactive TOA handling for large files @@ -219,16 +213,14 @@ def _worker_stdio_main() -> None: try: try: from libstempo import tempopulsar as _lib_tempopulsar # noqa - import numpy as _np # noqa + import numpy # noqa hello["hello"]["has_libstempo"] = True # best-effort tempo2 version probe try: from libstempo import tempo2 # type: ignore - hello["hello"]["tempo2_version"] = getattr( - tempo2, "TEMPO2_VERSION", None - ) + hello["hello"]["tempo2_version"] = getattr(tempo2, "TEMPO2_VERSION", None) except Exception: pass except Exception: @@ -240,11 +232,10 @@ def _worker_stdio_main() -> None: # If libstempo failed to import at hello, try once more here to return clean errors try: from libstempo import tempopulsar as _lib_tempopulsar # noqa - import numpy as _np # noqa + import numpy # noqa except Exception: # Keep serving, but report on first request _lib_tempopulsar: Optional[Any] = None - _np: Optional[Any] = None obj = None @@ -299,16 +290,12 @@ def _write_response(resp: Dict[str, Any]) -> None: # Handle methods try: if method == "bye": - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py("bye")} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py("bye")}) return if method == "rss": rss = _current_rss_mb_portable() - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(rss)} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(rss)}) continue if method == "ctor": @@ -365,17 +352,13 @@ def _write_response(resp: Dict[str, Any]) -> None: val = val.copy() except Exception: pass - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(val)} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(val)}) continue if method == "set": name, value = params["name"], params["value"] setattr(obj, name, value) - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) continue if method == "call": @@ -391,9 +374,7 @@ def _write_response(resp: Dict[str, Any]) -> None: out = out.copy() except Exception: pass - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)}) continue if method == "del": @@ -402,9 +383,7 @@ def _write_response(resp: Dict[str, Any]) -> None: except Exception: pass obj = None - _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)} - ) + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) continue _write_response( @@ -453,9 +432,7 @@ def _start(self, require_x86_64: bool = False): env = os.environ.copy() env.setdefault("PYTHONUNBUFFERED", "1") - logger.debug( - f"Launching subprocess with environment: PYTHONUNBUFFERED={env.get('PYTHONUNBUFFERED')}" - ) + logger.debug(f"Launching subprocess with environment: PYTHONUNBUFFERED={env.get('PYTHONUNBUFFERED')}") logger.debug(f"Subprocess working directory: {os.getcwd()}") self.proc = subprocess.Popen( self.cmd, @@ -494,13 +471,9 @@ def _start(self, require_x86_64: bool = False): if require_x86_64: if str(info.get("machine", "")).lower() != "x86_64": - 
logger.error( - f"Architecture mismatch: worker is {info.get('machine')}, but x86_64 required" - ) + logger.error(f"Architecture mismatch: worker is {info.get('machine')}, but x86_64 required") self._hard_kill() - raise Tempo2Error( - f"worker arch is {info.get('machine')}, but x86_64 is required for quad precision" - ) + raise Tempo2Error(f"worker arch is {info.get('machine')}, but x86_64 is required for quad precision") if not info.get("has_libstempo", False): logger.error("libstempo not available in worker environment") @@ -533,9 +506,7 @@ def _readline_with_timeout(self, timeout: Optional[float]) -> Optional[str]: # With timeout end = time.time() + timeout while time.time() < end: - rlist, _, _ = select.select( - [self.proc.stdout], [], [], max(0.01, end - time.time()) - ) + rlist, _, _ = select.select([self.proc.stdout], [], [], max(0.01, end - time.time())) if rlist: line = self.proc.stdout.readline() if not line: # EOF @@ -552,15 +523,10 @@ def _hard_kill(self): logger.warning(f"Failed to terminate process: {e}") pass t0 = time.time() - while ( - self.proc.poll() is None - and (time.time() - t0) < self.policy.kill_grace_s - ): + while self.proc.poll() is None and (time.time() - t0) < self.policy.kill_grace_s: time.sleep(0.01) if self.proc.poll() is None: - logger.warning( - f"Sending SIGKILL to worker process (PID: {self.proc.pid})" - ) + logger.warning(f"Sending SIGKILL to worker process (PID: {self.proc.pid})") with contextlib.suppress(Exception): os.kill(self.proc.pid, signal.SIGKILL) self.proc = None @@ -584,9 +550,7 @@ def __del__(self): # ---------- JSON-RPC helpers ---------- - def _send_rpc( - self, method: str, params: Dict[str, Any], timeout: Optional[float] = None - ) -> Any: + def _send_rpc(self, method: str, params: Dict[str, Any], timeout: Optional[float] = None) -> Any: if self.proc is None or self.proc.stdin is None or self.proc.stdout is None: logger.error("Worker not running, cannot send RPC") raise Tempo2Crashed("worker not running") @@ -636,13 +600,9 @@ def _send_rpc( raise Tempo2ProtocolError(f"malformed response: {resp_line!r}") if resp.get("id") != rid: - logger.error( - f"RPC {method} id mismatch: expected {rid}, got {resp.get('id')}" - ) + logger.error(f"RPC {method} id mismatch: expected {rid}, got {resp.get('id')}") self._hard_kill() - raise Tempo2ProtocolError( - f"mismatched id in response: {resp.get('id')} vs {rid}" - ) + raise Tempo2ProtocolError(f"mismatched id in response: {resp.get('id')} vs {rid}") if "error" in resp and resp["error"] is not None: err = resp["error"] @@ -659,9 +619,7 @@ def _send_rpc( def ctor(self, kwargs: Dict[str, Any], preload_residuals: bool): logger.info(f"Constructing tempopulsar with kwargs: {kwargs}") logger.info(f"Preload residuals: {preload_residuals}") - return self._send_rpc( - "ctor", {"kwargs": kwargs, "preload_residuals": preload_residuals} - ) + return self._send_rpc("ctor", {"kwargs": kwargs, "preload_residuals": preload_residuals}) def get(self, name: str): logger.debug(f"Getting attribute: {name}") @@ -673,9 +631,7 @@ def set(self, name: str, value: Any): def call(self, name: str, args=(), kwargs=None): logger.debug(f"Calling method: {name} with args={args}, kwargs={kwargs}") - return self._send_rpc( - "call", {"name": name, "args": tuple(args), "kwargs": dict(kwargs or {})} - ) + return self._send_rpc("call", {"name": name, "args": tuple(args), "kwargs": dict(kwargs or {})}) def rss(self) -> Optional[int]: try: @@ -767,9 +723,7 @@ def _resolve_worker_cmd(env_name: Optional[str]) -> Tuple[List[str], bool]: # 
Base invocation that runs this file in worker mode: # Find the src directory dynamically current_file = Path(__file__).resolve() - src_dir = ( - current_file.parent.parent - ) # Go up from libstempo/sandbox.py to src/ + src_dir = current_file.parent.parent # Go up from libstempo/sandbox.py to src/ src_path = str(src_dir) def python_to_worker_cmd(python_exe: str) -> List[str]: @@ -836,8 +790,7 @@ def python_to_worker_cmd(python_exe: str) -> List[str]: return (arch + python_to_worker_cmd(py), require_x86_64) raise Tempo2Error( - f"Environment '{env_name}' not found. " - "Use a conda env name, a venv name, 'arch', or 'python:/abs/python'." + f"Environment '{env_name}' not found. " "Use a conda env name, a venv name, 'arch', or 'python:/abs/python'." ) @@ -896,30 +849,22 @@ def __init__(self, env_name: Optional[str] = None, **kwargs): self._state = _State(created_at=time.time(), calls_ok=0) self._require_x86 = False - logger.info( - f"Creating tempopulsar with env_name='{env_name}', kwargs={self._ctor_kwargs}" - ) - logger.info( - f"Using policy: ctor_retry={self._policy.ctor_retry}, ctor_backoff={self._policy.ctor_backoff}s" - ) + logger.info(f"Creating tempopulsar with env_name='{env_name}', kwargs={self._ctor_kwargs}") + logger.info(f"Using policy: ctor_retry={self._policy.ctor_retry}, ctor_backoff={self._policy.ctor_backoff}s") self._construct_with_retries() # --------------- construction / reconstruction with retries --------------- # def _construct_with_retries(self): - logger.info( - f"Starting construction with {self._policy.ctor_retry + 1} total attempts" - ) - + logger.info(f"Starting construction with {self._policy.ctor_retry + 1} total attempts") + # Proactive TOA counting to avoid "Too many TOAs" errors if self._policy.auto_nobs_retry: self._proactive_nobs_setup() - + last_exc: Optional[Exception] = None for attempt in range(1 + self._policy.ctor_retry): - logger.info( - f"Construction attempt {attempt + 1}/{self._policy.ctor_retry + 1}" - ) + logger.info(f"Construction attempt {attempt + 1}/{self._policy.ctor_retry + 1}") try: cmd, require_x86 = _resolve_worker_cmd(self._env_name) self._require_x86 = require_x86 @@ -929,9 +874,7 @@ def _construct_with_retries(self): self._wp = _WorkerProc(self._policy, cmd, require_x86_64=require_x86) # ctor on the worker (libstempo.tempopulsar) logger.info("Calling constructor on worker...") - self._wp.ctor( - self._ctor_kwargs, preload_residuals=self._policy.preload_residuals - ) + self._wp.ctor(self._ctor_kwargs, preload_residuals=self._policy.preload_residuals) self._state.created_at = time.time() self._state.calls_ok = 0 logger.info(f"Construction successful on attempt {attempt + 1}") @@ -952,34 +895,36 @@ def _construct_with_retries(self): logger.info(f"Waiting {self._policy.ctor_backoff}s before retry...") time.sleep(self._policy.ctor_backoff) logger.error(f"All construction attempts failed. 
Last error: {last_exc}") - raise Tempo2ConstructorFailed( - f"tempopulsar ctor failed after retries: {last_exc}" - ) + raise Tempo2ConstructorFailed(f"tempopulsar ctor failed after retries: {last_exc}") def _proactive_nobs_setup(self): """Proactively count TOAs and add nobs parameter if needed to avoid 'Too many TOAs' errors.""" try: - timfile = self._ctor_kwargs.get('timfile') + timfile = self._ctor_kwargs.get("timfile") if not timfile: logger.debug("No timfile specified, skipping proactive nobs setup") return - + timfile_path = Path(timfile) if not timfile_path.exists(): logger.warning(f"TIM file does not exist: {timfile_path}") return - + logger.info(f"Proactively counting TOAs in {timfile_path}") analyzer = TimFileAnalyzer() toa_count = analyzer.count_toas(timfile_path) - + if toa_count > self._policy.nobs_threshold: maxobs_with_margin = int(toa_count * self._policy.nobs_safety_margin) - self._ctor_kwargs['maxobs'] = maxobs_with_margin - logger.info(f"Proactively added maxobs={maxobs_with_margin} parameter (TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, margin: {self._policy.nobs_safety_margin})") + self._ctor_kwargs["maxobs"] = maxobs_with_margin + logger.info( + f"Proactively added maxobs={maxobs_with_margin} parameter (TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, margin: {self._policy.nobs_safety_margin})" + ) else: - logger.debug(f"TOA count {toa_count} below threshold {self._policy.nobs_threshold}, no maxobs parameter needed") - + logger.debug( + f"TOA count {toa_count} below threshold {self._policy.nobs_threshold}, no maxobs parameter needed" + ) + except Exception as e: logger.warning(f"Proactive nobs setup failed: {e}") # Don't raise - this is just optimization, construction should still work @@ -995,16 +940,11 @@ def _should_recycle(self) -> bool: # Check age limit (if set) if self._policy.max_age_s is not None and age > self._policy.max_age_s: - logger.info( - f"Should recycle: worker age {age:.1f}s exceeds max_age_s {self._policy.max_age_s}" - ) + logger.info(f"Should recycle: worker age {age:.1f}s exceeds max_age_s {self._policy.max_age_s}") return True # Check call limit (if set) - if ( - self._policy.max_calls_per_worker is not None - and self._state.calls_ok >= self._policy.max_calls_per_worker - ): + if self._policy.max_calls_per_worker is not None and self._state.calls_ok >= self._policy.max_calls_per_worker: logger.info( f"Should recycle: calls_ok {self._state.calls_ok} exceeds max_calls_per_worker {self._policy.max_calls_per_worker}" ) @@ -1014,14 +954,10 @@ def _should_recycle(self) -> bool: if self._policy.rss_soft_limit_mb is not None: rss = self._wp.rss() if rss and rss > self._policy.rss_soft_limit_mb: - logger.info( - f"Should recycle: RSS {rss}MB exceeds limit {self._policy.rss_soft_limit_mb}MB" - ) + logger.info(f"Should recycle: RSS {rss}MB exceeds limit {self._policy.rss_soft_limit_mb}MB") return True - logger.debug( - f"Worker still healthy: age={age:.1f}s, calls={self._state.calls_ok}" - ) + logger.debug(f"Worker still healthy: age={age:.1f}s, calls={self._state.calls_ok}") return False def _recycle(self): @@ -1050,9 +986,7 @@ def _rpc(self, call: str, **payload): elif call == "set": out = self._wp.set(payload["name"], payload["value"]) elif call == "call": - out = self._wp.call( - payload["name"], payload.get("args", ()), payload.get("kwargs", {}) - ) + out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) else: raise Tempo2ProtocolError(f"unknown call {call}") self._state.calls_ok += 1 @@ 
-1069,33 +1003,29 @@ def _rpc(self, call: str, **payload): elif call == "set": out = self._wp.set(payload["name"], payload["value"]) else: - out = self._wp.call( - payload["name"], payload.get("args", ()), payload.get("kwargs", {}) - ) + out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) self._state.calls_ok += 1 - logger.info( - f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}" - ) + logger.info(f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}") return out # ------------------------ Attribute proxying magic ------------------------ # def __getattr__(self, name: str): # Filter out IPython-specific attributes to prevent infinite loops - if name.startswith('_ipython_') or name in { - '_ipython_canary_method_should_not_exist_', - '_repr_mimebundle_', - '_repr_html_', - '_repr_json_', - '_repr_latex_', - '_repr_png_', - '_repr_jpeg_', - '_repr_svg_', - '_repr_pdf_', + if name.startswith("_ipython_") or name in { + "_ipython_canary_method_should_not_exist_", + "_repr_mimebundle_", + "_repr_html_", + "_repr_json_", + "_repr_latex_", + "_repr_png_", + "_repr_jpeg_", + "_repr_svg_", + "_repr_pdf_", }: logger.debug(f"Filtering out IPython attribute: {name}") raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - + def _remote_method(*args, **kwargs): return self._rpc("call", name=name, args=args, kwargs=kwargs) @@ -1160,12 +1090,8 @@ def load_many( failed_list: [LoadReport,...] """ pol = policy if isinstance(policy, Policy) else Policy() - logger.info( - f"Starting bulk load of {len(list(pairs))} pulsars with {parallel} parallel workers" - ) - logger.info( - f"Using policy: ctor_retry={pol.ctor_retry}, ctor_backoff={pol.ctor_backoff}s" - ) + logger.info(f"Starting bulk load of {len(list(pairs))} pulsars with {parallel} parallel workers") + logger.info(f"Using policy: ctor_retry={pol.ctor_retry}, ctor_backoff={pol.ctor_backoff}s") def _one(par, tim): """Load a single pulsar with retry logic for bulk loading.""" @@ -1208,18 +1134,14 @@ def _one(par, tim): else: failed.append(report) - logger.info( - f"Bulk load completed: {len(ok)} successful, {len(retried)} retried, {len(failed)} failed" - ) + logger.info(f"Bulk load completed: {len(ok)} successful, {len(retried)} retried, {len(failed)} failed") return ok, retried, failed # ------------------------------- Quick helpers ------------------------------ # -def configure_logging( - level: str = "INFO", log_file: Optional[str] = None, enable_console: bool = True -): +def configure_logging(level: str = "INFO", log_file: Optional[str] = None, enable_console: bool = True): """ Configure standard logging for the sandbox. 
@@ -1230,35 +1152,30 @@ def configure_logging( """ # Get the sandbox logger sandbox_logger = logging.getLogger(__name__) - + # Clear existing handlers sandbox_logger.handlers.clear() - + # Set level numeric_level = getattr(logging, level.upper(), logging.INFO) sandbox_logger.setLevel(numeric_level) - + # Create formatter - formatter = logging.Formatter( - '%(asctime)s | %(levelname)-8s | tempo2_sandbox | %(message)s', - datefmt='%H:%M:%S' - ) - + formatter = logging.Formatter("%(asctime)s | %(levelname)-8s | tempo2_sandbox | %(message)s", datefmt="%H:%M:%S") + # Add console handler if requested if enable_console: console_handler = logging.StreamHandler(sys.stderr) console_handler.setFormatter(formatter) sandbox_logger.addHandler(console_handler) - + # Add file handler if requested if log_file: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) sandbox_logger.addHandler(file_handler) - - logger.info( - f"Logging configured: level={level}, console={enable_console}, file={log_file}" - ) + + logger.info(f"Logging configured: level={level}, console={enable_console}, file={log_file}") def setup_instructions(env_name: str = "tempo2_intel"): @@ -1280,9 +1197,7 @@ def setup_instructions(env_name: str = "tempo2_intel"): print(f' # then just: psr = tempopulsar(..., env_name="{env_name}")') print("\n3. System Python with Rosetta:") print(" # Install Intel Python first (or use system one under arch).") - print( - ' # You can force Rosetta via TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64"' - ) + print(' # You can force Rosetta via TEMPO2_SANDBOX_WORKER_ARCH_PREFIX="arch -x86_64"') print(' # then: psr = tempopulsar(..., env_name="arch")') @@ -1307,9 +1222,7 @@ def detect_and_guide(env_name: str): elif et == "python": print("✅ Using explicit Python path.") else: - print( - "❌ Not found. Use conda env name, venv name, 'arch', or 'python:/abs/python'." - ) + print("❌ Not found. 
Use conda env name, venv name, 'arch', or 'python:/abs/python'.") # ------------------------------ Module runner ------------------------------- # diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 3057fc2..49256e5 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -20,9 +20,7 @@ def test_basic_sandbox_usage(self): def test_policy_configuration(self): """Test Policy configuration""" policy = Policy(ctor_retry=2, call_timeout_s=30.0) - psr = tempopulsar( - parfile=self.parfile, timfile=self.timfile, policy=policy - ) + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, policy=policy) self.assertEqual(psr.name, "1909-3744") def test_designmatrix_call(self): @@ -75,16 +73,16 @@ def test_combined_analysis(self): def test_cache_functionality(self): """Test that caching works correctly""" analyzer = TimFileAnalyzer() - + # First call timespan1, count1 = analyzer.get_timespan_and_count(self.timfile) - + # Second call should use cache timespan2, count2 = analyzer.get_timespan_and_count(self.timfile) - + self.assertEqual(timespan1, timespan2) self.assertEqual(count1, count2) - + # Clear cache and test again analyzer.clear_cache() timespan3, count3 = analyzer.get_timespan_and_count(self.timfile) From 7f915923eaea0bc4cabad7575e7c98d60929ee8d Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Fri, 10 Oct 2025 19:39:24 +0200 Subject: [PATCH 03/16] Changed some comments --- libstempo/sandbox.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 5dfca09..5e622e5 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -1,5 +1,7 @@ -# sandbox.py """ +Author: Rutger van Haasteren -- rutger@vhaasteren.com +Date: 2025-10-10 + Process sandbox for libstempo/tempo2 that keeps each pulsar in its own clean subprocess. A segfault in tempo2/libstempo only kills the worker, not your kernel. 
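+
+Worker I/O is line-delimited JSON-RPC over stdin/stdout, with Python values
+carried as base64-pickled blobs. An illustrative request frame (payload
+abbreviated):
+
+    {"jsonrpc": "2.0", "id": 1, "method": "get", "params_b64": "<b64 pickle of {'name': 'nobs'}>"}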
From 506f1983cda2d155acf2f1537a39b8b654188bde Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Fri, 10 Oct 2025 18:44:43 +0000 Subject: [PATCH 04/16] Fix pyproject.toml license format for PEP 621 compliance --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 091ce44..1b86486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ description = "A Python wrapper for tempo2" authors = [{name = "Michele Vallisneri", email = "vallis@vallis.org"}] urls = { Homepage = "https://github.com/vallis/libstempo" } readme = "README.md" -license = "MIT" +license = {text = "MIT"} license-files = [ "LICENSE" ] classifiers=[ "Intended Audience :: Developers", From 456d83f29588589d08002d99b3080a8088cf07b4 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sat, 11 Oct 2025 06:01:19 +0000 Subject: [PATCH 05/16] Fix pyproject.toml license to use file-based format --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1b86486..73822d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,7 @@ description = "A Python wrapper for tempo2" authors = [{name = "Michele Vallisneri", email = "vallis@vallis.org"}] urls = { Homepage = "https://github.com/vallis/libstempo" } readme = "README.md" -license = {text = "MIT"} -license-files = [ "LICENSE" ] +license = {file = "LICENSE"} classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Science/Research", From 112c340283b627840ade272fbed3dc1a8a0079ed Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sat, 11 Oct 2025 06:12:34 +0000 Subject: [PATCH 06/16] Fix linting issues: line length and unused imports - Break long lines in sandbox.py to fit 120-char limit - Add noqa comments for imports in __init__.py - Format tim_file_analyzer.py with black --- libstempo/__init__.py | 4 +- libstempo/sandbox.py | 7 ++- libstempo/tim_file_analyzer.py | 90 ++++++++++++++-------------------- 3 files changed, 45 insertions(+), 56 deletions(-) diff --git a/libstempo/__init__.py b/libstempo/__init__.py index 8969fae..519801f 100644 --- a/libstempo/__init__.py +++ b/libstempo/__init__.py @@ -2,8 +2,8 @@ from ._find_tempo2 import find_tempo2_runtime # Import sandbox functionality -from .sandbox import tempopulsar as sandbox_tempopulsar, Policy, configure_logging -from .tim_file_analyzer import TimFileAnalyzer +from .sandbox import tempopulsar as sandbox_tempopulsar, Policy, configure_logging # noqa: F401 +from .tim_file_analyzer import TimFileAnalyzer # noqa: F401 # check to see if TEMPO2 environment variable is set TEMPO2_RUNTIME = os.getenv("TEMPO2") diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 5e622e5..dd78921 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -920,7 +920,9 @@ def _proactive_nobs_setup(self): maxobs_with_margin = int(toa_count * self._policy.nobs_safety_margin) self._ctor_kwargs["maxobs"] = maxobs_with_margin logger.info( - f"Proactively added maxobs={maxobs_with_margin} parameter (TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, margin: {self._policy.nobs_safety_margin})" + f"Proactively added maxobs={maxobs_with_margin} parameter " + f"(TOAs: {toa_count}, threshold: {self._policy.nobs_threshold}, " + f"margin: {self._policy.nobs_safety_margin})" ) else: logger.debug( @@ -948,7 +950,8 @@ def _should_recycle(self) -> bool: # Check call limit (if set) if self._policy.max_calls_per_worker is not None and self._state.calls_ok >= 
self._policy.max_calls_per_worker: logger.info( - f"Should recycle: calls_ok {self._state.calls_ok} exceeds max_calls_per_worker {self._policy.max_calls_per_worker}" + f"Should recycle: calls_ok {self._state.calls_ok} exceeds " + f"max_calls_per_worker {self._policy.max_calls_per_worker}" ) return True diff --git a/libstempo/tim_file_analyzer.py b/libstempo/tim_file_analyzer.py index d40ae36..23ba0cd 100644 --- a/libstempo/tim_file_analyzer.py +++ b/libstempo/tim_file_analyzer.py @@ -122,37 +122,41 @@ def _toa_format(line: str, fmt: str = "Unknown") -> str: """Determine the type of a TOA line. - + Identifies a TOA line as one of the following types: Comment, Command, Blank, Tempo2, Princeton, ITOA, Parkes, Unknown. - + This replicates PINT's _toa_format function. """ # Check for comments first if ( - line.startswith("C ") and len(line) > 2 and not line[2].isdigit() # C followed by non-digit - or line.startswith("c ") and len(line) > 2 and not line[2].isdigit() # c followed by non-digit + line.startswith("C ") + and len(line) > 2 + and not line[2].isdigit() # C followed by non-digit + or line.startswith("c ") + and len(line) > 2 + and not line[2].isdigit() # c followed by non-digit or line.startswith("#") or line.startswith("CC ") ): return "Comment" - + # Check for commands if line.upper().lstrip().startswith(TOA_COMMANDS): return "Command" - + # Check for blank lines if re.match(r"^\s*$", line): return "Blank" - + # Check for Princeton format: starts with single observatory code followed by space if re.match(r"[0-9a-zA-Z@] ", line): return "Princeton" - + # Check for Tempo2 format: long lines, explicitly marked as Tempo2, or structured like Tempo2 if len(line) > 80 or fmt == "Tempo2": return "Tempo2" - + # Additional Tempo2 detection: if it looks like a Tempo2 TOA line (has 5+ space-separated fields) fields = line.split() if len(fields) >= 5: @@ -160,34 +164,34 @@ def _toa_format(line: str, fmt: str = "Unknown") -> str: try: # Try to parse as Tempo2: name freq mjd error obs float(fields[1]) # frequency should be numeric - float(fields[2]) # MJD should be numeric + float(fields[2]) # MJD should be numeric float(fields[3]) # error should be numeric # If we get here, it looks like Tempo2 format return "Tempo2" except (ValueError, IndexError): pass - + # Check for Parkes format: starts with space, has decimal at position 42 if re.match(r"^ ", line) and len(line) > 41 and line[41] == ".": return "Parkes" - + # Check for ITOA format: two non-space chars, decimal at position 15 if re.match(r"\S\S", line) and len(line) > 14 and line[14] == ".": return "ITOA" - + # Default to Unknown return "Unknown" def _get_observatory_name(obs_code: str) -> str: """Get observatory name from observatory code. - + This is a simplified version of PINT's get_observatory function that only handles the most common observatory codes. - + Args: obs_code: Observatory code (e.g., 'A', 'AO', '@') - + Returns: Observatory name """ @@ -197,17 +201,17 @@ def _get_observatory_name(obs_code: str) -> str: def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int, float]], dict]: """Parse a one-line ASCII time-of-arrival. - + Return an MJD tuple and a dictionary of other TOA information. The format can be one of: Comment, Command, Blank, Tempo2, Princeton, ITOA, Parkes, or Unknown. - + This replicates PINT's _parse_TOA_line function. 
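+
+    For example, an MJD field of ``53000.123456`` is returned as the tuple
+    ``(53000, 0.123456)``.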
""" MJD = None fmt = _toa_format(line, fmt) d = dict(format=fmt) - + if fmt == "Princeton": # Princeton format # ---------------- @@ -259,7 +263,7 @@ def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int # If parsing fails, treat as unknown format logger.debug(f"Failed to parse Princeton format line: {e}") d["format"] = "Unknown" - + elif fmt == "Tempo2": # This could use more error catching... try: @@ -275,14 +279,14 @@ def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int d["obs"] = _get_observatory_name(fields[4].upper()) # All the rest should be flags flags = fields[5:] - + # Flags and flag-values should be given in pairs. # The for loop below will fail otherwise. if len(flags) % 2 != 0: raise ValueError( f"Flags and flag-values should be given in pairs. The given flags are {' '.join(flags)}" ) - + for i in range(0, len(flags), 2): k, v = flags[i].lstrip("-"), flags[i + 1] if k in ["error", "freq", "scale", "MJD", "flags", "obs", "name"]: @@ -294,7 +298,7 @@ def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int # If parsing fails, treat as unknown format logger.debug(f"Failed to parse Tempo2 format line: {e}") d["format"] = "Unknown" - + elif fmt == "Command": d[fmt] = line.split() elif fmt == "Parkes": @@ -315,22 +319,18 @@ def _parse_TOA_line(line: str, fmt: str = "Unknown") -> Tuple[Optional[Tuple[int MJD = int(ii), float(f"0.{ff}") phaseoffset = float(line[55:62]) if phaseoffset != 0: - raise ValueError( - f"Cannot interpret Parkes format with phaseoffset={phaseoffset} yet" - ) + raise ValueError(f"Cannot interpret Parkes format with phaseoffset={phaseoffset} yet") d["error"] = float(line[63:71]) d["obs"] = _get_observatory_name(line[79].upper()) except (ValueError, IndexError) as e: # If parsing fails, treat as unknown format logger.debug(f"Failed to parse Parkes format line: {e}") d["format"] = "Unknown" - + elif fmt == "ITOA": raise RuntimeError(f"TOA format '{fmt}' not implemented yet") elif fmt not in ["Blank", "Comment"]: - raise RuntimeError( - f"Unable to identify TOA format for line {line!r}, expecting {fmt}" - ) + raise RuntimeError(f"Unable to identify TOA format for line {line!r}, expecting {fmt}") return MJD, d @@ -340,7 +340,7 @@ class TimFileAnalyzer: This class efficiently extracts TOA MJD values from TIM files using independent parsing logic that replicates PINT's functionality, providing both performance and robustness for timespan calculations. - + The analyzer caches results per file to avoid duplicate parsing when both timespan and TOA count are needed for the same file. 
""" @@ -381,18 +381,14 @@ def _get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: self._file_cache[tim_file_path] = (timespan, toa_count) if toa_count > 0: - self.logger.debug( - f"Cached data for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs" - ) + self.logger.debug(f"Cached data for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs") else: self.logger.debug(f"Cached data for {tim_file_path}: No TOAs found") return timespan, toa_count except Exception as e: self.logger.warning(f"Parsing failed for {tim_file_path}: {e}") - self.logger.debug( - "File may contain non-standard TIM format or malformed data" - ) + self.logger.debug("File may contain non-standard TIM format or malformed data") # Cache empty result to avoid repeated failures empty_result = (0.0, 0) self._file_cache[tim_file_path] = empty_result @@ -413,9 +409,7 @@ def calculate_timespan(self, tim_file_path: Path) -> float: self.logger.warning(f"No TOAs found in {tim_file_path}") return 0.0 - self.logger.debug( - f"Timespan for {tim_file_path}: {timespan:.1f} days ({toa_count} TOAs)" - ) + self.logger.debug(f"Timespan for {tim_file_path}: {timespan:.1f} days ({toa_count} TOAs)") return timespan def count_toas(self, tim_file_path: Path) -> int: @@ -429,9 +423,7 @@ def count_toas(self, tim_file_path: Path) -> int: """ _, toa_count = self._get_timespan_and_count(tim_file_path) - self.logger.debug( - f"TOA count for {tim_file_path}: {toa_count} TOAs" - ) + self.logger.debug(f"TOA count for {tim_file_path}: {toa_count} TOAs") return toa_count def clear_cache(self) -> None: @@ -454,9 +446,7 @@ def get_timespan_and_count(self, tim_file_path: Path) -> Tuple[float, int]: self.logger.warning(f"No TOAs found in {tim_file_path}") return 0.0, 0 - self.logger.debug( - f"Timespan and count for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs" - ) + self.logger.debug(f"Timespan and count for {tim_file_path}: {timespan:.1f} days, {toa_count} TOAs") return timespan, toa_count def _extract_mjd_values_recursive(self, tim_file_path: Path) -> List[float]: @@ -491,9 +481,7 @@ def _extract_mjd_values_recursive(self, tim_file_path: Path) -> List[float]: mjd_tuple, d = _parse_TOA_line(line) except Exception as e: # Parsing may fail on malformed lines - skip them gracefully - self.logger.debug( - f"Skipping malformed line in {tim_file_path}: {line.strip()} - {e}" - ) + self.logger.debug(f"Skipping malformed line in {tim_file_path}: {line.strip()} - {e}") continue # Handle commands (especially INCLUDE) @@ -516,9 +504,7 @@ def _extract_mjd_values_recursive(self, tim_file_path: Path) -> List[float]: return mjd_values - def _handle_command( - self, d: dict, current_file: Path, mjd_values: List[float] - ) -> None: + def _handle_command(self, d: dict, current_file: Path, mjd_values: List[float]) -> None: """Handle TIM file commands using parsed command data. 
Args: From 5d24e4925430f7e1d81295695a34049b89a1a6c9 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sat, 11 Oct 2025 09:20:46 +0200 Subject: [PATCH 07/16] Trigger CI re-run From 179e7981b36470f818539790d147c1964cc9f650 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sat, 11 Oct 2025 22:10:33 +0200 Subject: [PATCH 08/16] =?UTF-8?q?sandbox:=20stdio=20routing=20(FD1?= =?UTF-8?q?=E2=86=92FD2),=20stderr=20drain=20thread,=20logs(tail);=20tests?= =?UTF-8?q?:=20parity=20+=20logs=20via=20savepar()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libstempo/sandbox.py | 65 +++++++++++++++++++++++--------------- tests/test_sandbox.py | 73 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 26 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index dd78921..8c0e6e0 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -201,6 +201,13 @@ def _worker_stdio_main() -> None: Each request's 'params_b64' is a pickled dict of parameters. Each response uses 'result_b64' for Python results, or 'error'. """ + # Permanently redirect C-level stdout (FD 1) to stderr (FD 2), + # while keeping JSON-RPC on a dedicated duplicate of the original stdout pipe. + import os as _os_for_fds + _proto_fd = _os_for_fds.dup(1) # save original stdout FD for protocol + _os_for_fds.dup2(2, 1) # route any C/printf stdout to stderr + sys.stdout = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) + # Step 1: hello handshake hello = { "hello": { @@ -304,32 +311,9 @@ def _write_response(resp: Dict[str, Any]) -> None: if _lib_tempopulsar is None: raise ImportError("libstempo not available in worker") - # Suppress stdout/stderr during constructor to prevent libstempo debug output - # from contaminating the JSON-RPC protocol. We need to redirect at the OS level - # because tempo2 writes directly to file descriptors. 
- import os - - # Save original stdout/stderr file descriptors - original_stdout = os.dup(1) - original_stderr = os.dup(2) - - try: - # Redirect stdout/stderr to /dev/null - devnull = os.open(os.devnull, os.O_WRONLY) - os.dup2(devnull, 1) # stdout - os.dup2(devnull, 2) # stderr - - obj = _lib_tempopulsar(**params["kwargs"]) - if params.get("preload_residuals", True): - _ = obj.residuals(updatebats=True, formresiduals=True) - - finally: - # Restore original stdout/stderr - os.dup2(original_stdout, 1) - os.dup2(original_stderr, 2) - os.close(devnull) - os.close(original_stdout) - os.close(original_stderr) + obj = _lib_tempopulsar(**params["kwargs"]) + if params.get("preload_residuals", True): + _ = obj.residuals(updatebats=True, formresiduals=True) _write_response( { @@ -448,6 +432,26 @@ def _start(self, require_x86_64: bool = False): logger.debug(f"Worker process started with PID: {self.proc.pid}") + # Start background stderr drain to avoid backpressure and capture logs + import threading, collections + self._log_buf = collections.deque(maxlen=20000) + + def _drain_stderr(pipe, sink_deque): + try: + for line in iter(pipe.readline, ''): + line = line.rstrip('\n') + sink_deque.append(line) + logger.debug("[tempo2-stderr] %s", line) + finally: + with contextlib.suppress(Exception): + pipe.close() + + if self.proc.stderr is not None: + self._stderr_thread = threading.Thread( + target=_drain_stderr, args=(self.proc.stderr, self._log_buf), daemon=True + ) + self._stderr_thread.start() + # Hello handshake (one line of JSON) logger.debug("Waiting for worker hello handshake...") hello = self._readline_with_timeout(self.policy.call_timeout_s) @@ -643,6 +647,12 @@ def rss(self) -> Optional[int]: logger.warning(f"Failed to get RSS: {e}") return None + def logs(self, tail: int = 500) -> str: + try: + return "\n".join(list(self._log_buf)[-max(0, tail):]) + except Exception: + return "" + # ------------------------- Command resolution (env_name) -------------------- # @@ -1062,6 +1072,9 @@ def toas(self, **kwargs): def fit(self, **kwargs): return self._rpc("call", name="fit", kwargs=kwargs) + def logs(self, tail: int = 500) -> str: + return self._wp.logs(tail) if self._wp else "" + def __del__(self): with contextlib.suppress(Exception): if self._wp is not None: diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 49256e5..1bca08e 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -2,6 +2,9 @@ import libstempo as t2 from libstempo.sandbox import tempopulsar, Policy, configure_logging from libstempo.tim_file_analyzer import TimFileAnalyzer +import tempfile +import numpy as np +from numpy.testing import assert_allclose class TestSandbox(unittest.TestCase): @@ -42,6 +45,76 @@ def test_logging_configuration(self): configure_logging(level="DEBUG", enable_console=False) configure_logging(level="INFO", enable_console=True) + def test_logs_readout(self): + """Test that logs() captures tempo2 stdout/stderr by invoking savepar(), which prints.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + # Baseline logs + logs_before = psr.logs(2000) + self.assertIsInstance(logs_before, str) + # Invoke an operation that emits tempo2 text output + tmp_par = tempfile.NamedTemporaryFile(delete=True) + tmp_par.close() + _ = psr.savepar(tmp_par.name) + # Give background drain thread a moment to process + import time as _t + _t.sleep(0.1) + logs_after = psr.logs(8000) + self.assertIsInstance(logs_after, str) + # Expect noticeable output; check growth and presence of a known 
token + self.assertGreater(len(logs_after), len(logs_before)) + self.assertIn("Results for PSR", logs_after) + + def test_sandbox_native_parity(self): + """Compare key attributes and outputs between sandbox and native tempopulsar.""" + psr_s = tempopulsar(parfile=self.parfile, timfile=self.timfile) + psr_n = t2.tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Scalar/string attributes + self.assertEqual(psr_s.name, psr_n.name) + self.assertEqual(psr_s.nobs, psr_n.nobs) + + # Arrays: use a tight tolerance + assert_allclose(psr_s.stoas, psr_n.stoas, rtol=0, atol=0) + + # Residuals + res_s = psr_s.residuals() + res_n = psr_n.residuals() + self.assertEqual(res_s.shape, res_n.shape) + assert_allclose(res_s, res_n, rtol=0, atol=0) + + # Design matrix + dm_s = psr_s.designmatrix() + dm_n = psr_n.designmatrix() + self.assertEqual(dm_s.shape, dm_n.shape) + assert_allclose(dm_s, dm_n, rtol=0, atol=0) + + # TOAs + toas_s = psr_s.toas() + toas_n = psr_n.toas() + self.assertEqual(len(toas_s), len(toas_n)) + assert_allclose(np.asarray(toas_s), np.asarray(toas_n), rtol=0, atol=0) + + # Timing model parameters (subset): ensure values and metadata match + s_all = set(psr_s.pars(which="all")) + n_all = set(psr_n.pars(which="all")) + s_fit = set(psr_s.pars()) # defaults to fitted + n_fit = set(psr_n.pars()) + s_set = set(psr_s.pars(which="set")) + n_set = set(psr_n.pars(which="set")) + + for par_name in ["RAJ", "DECJ"]: + if par_name in s_all and par_name in n_all: + # numeric values and errors via bulk accessors + s_val = np.asarray(psr_s.vals(which=[par_name]))[0] + n_val = np.asarray(psr_n.vals(which=[par_name]))[0] + s_err = np.asarray(psr_s.errs(which=[par_name]))[0] + n_err = np.asarray(psr_n.errs(which=[par_name]))[0] + assert_allclose(s_val, n_val, rtol=0, atol=0) + assert_allclose(s_err, n_err, rtol=0, atol=0) + # fit/set flags via pars() groups + self.assertEqual(par_name in s_fit, par_name in n_fit) + self.assertEqual(par_name in s_set, par_name in n_set) + class TestTimFileAnalyzer(unittest.TestCase): @classmethod From 2aa631a1bf3c9d303630babd1490c7166f273a12 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sat, 11 Oct 2025 22:15:19 +0200 Subject: [PATCH 09/16] =?UTF-8?q?sandbox:=20stdio=20routing=20(FD1?= =?UTF-8?q?=E2=86=92FD2),=20stderr=20drain=20thread,=20logs(tail);=20tests?= =?UTF-8?q?:=20parity=20+=20logs=20via=20savepar()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libstempo/sandbox.py | 14 +++++++++----- tests/test_sandbox.py | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 8c0e6e0..95c925d 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -1,3 +1,4 @@ +# flake8: noqa: E501 """ Author: Rutger van Haasteren -- rutger@vhaasteren.com Date: 2025-10-10 @@ -204,6 +205,7 @@ def _worker_stdio_main() -> None: # Permanently redirect C-level stdout (FD 1) to stderr (FD 2), # while keeping JSON-RPC on a dedicated duplicate of the original stdout pipe. 
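+    # (Net effect: C-level printf output from tempo2 lands on stderr and is
+    # captured by the parent's drain thread, while sys.stdout is rebound to a
+    # private dup of the original pipe so the protocol stream stays clean.)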
import os as _os_for_fds + _proto_fd = _os_for_fds.dup(1) # save original stdout FD for protocol _os_for_fds.dup2(2, 1) # route any C/printf stdout to stderr sys.stdout = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) @@ -311,7 +313,7 @@ def _write_response(resp: Dict[str, Any]) -> None: if _lib_tempopulsar is None: raise ImportError("libstempo not available in worker") - obj = _lib_tempopulsar(**params["kwargs"]) + obj = _lib_tempopulsar(**params["kwargs"]) if params.get("preload_residuals", True): _ = obj.residuals(updatebats=True, formresiduals=True) @@ -433,13 +435,15 @@ def _start(self, require_x86_64: bool = False): logger.debug(f"Worker process started with PID: {self.proc.pid}") # Start background stderr drain to avoid backpressure and capture logs - import threading, collections + import threading + import collections + self._log_buf = collections.deque(maxlen=20000) def _drain_stderr(pipe, sink_deque): try: - for line in iter(pipe.readline, ''): - line = line.rstrip('\n') + for line in iter(pipe.readline, ""): + line = line.rstrip("\n") sink_deque.append(line) logger.debug("[tempo2-stderr] %s", line) finally: @@ -649,7 +653,7 @@ def rss(self) -> Optional[int]: def logs(self, tail: int = 500) -> str: try: - return "\n".join(list(self._log_buf)[-max(0, tail):]) + return "\n".join(list(self._log_buf)[-max(0, tail) :]) except Exception: return "" diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 1bca08e..9cd33e1 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -46,7 +46,7 @@ def test_logging_configuration(self): configure_logging(level="INFO", enable_console=True) def test_logs_readout(self): - """Test that logs() captures tempo2 stdout/stderr by invoking savepar(), which prints.""" + """Test that logs() captures tempo2 output via savepar().""" psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) # Baseline logs logs_before = psr.logs(2000) @@ -57,6 +57,7 @@ def test_logs_readout(self): _ = psr.savepar(tmp_par.name) # Give background drain thread a moment to process import time as _t + _t.sleep(0.1) logs_after = psr.logs(8000) self.assertIsInstance(logs_after, str) From 50304f5959a51d61bb053b41db2eff4005514b31 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sun, 12 Oct 2025 15:38:03 +0000 Subject: [PATCH 10/16] sandbox: robust, backward-compatible process isolation for tempopulsar - Protocol - Add hello proto_version=1.2 and capabilities: get_kind, dir, setitem, get_slice, path_access - Non-exceptional attribute discovery (get-kind) and optional dir RPC - Array semantics - Introduce write-through ArrayProxy for numpy-backed attrs (stoas, toaerrs, freqs) - Reads expose plain numpy via __array__; __repr__/__str__/__getattr__ delegate to ndarray - Writes route via setitem RPC; add get_slice RPC to avoid fetching whole arrays for reads - Guard __len__ for 0-d; support fancy/masked indexing; optional safe dtype cast on set - Dotted paths - Gate first-hop mapping access to mapping-like (__getitem__) objects only - Support psr['PAR'].val/err/fit/set via dotted-path resolution - Process lifecycle & IO - Popen: pass env, close_fds, start_new_session, Windows CREATE_NEW_PROCESS_GROUP when available - Group kill with POSIX killpg; Windows terminate/kill fallbacks - Thread-safe RPC framing with a per-worker send lock - Errors & logging - Stderr ring with optional tail included in exceptions; cap tail by bytes (16KiB) and lines - Tests - Add unit tests comparing sandbox vs native for parameter mapping and TOA edits+fit - Full suite 
green: array writes now update worker; residuals match native after fit --- libstempo/sandbox.py | 527 +++++++++++++++++++++++++++++++++++++++--- tests/test_sandbox.py | 63 +++++ 2 files changed, 560 insertions(+), 30 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 95c925d..9b7b7f3 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -80,6 +80,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple +from collections.abc import Mapping # Import TimFileAnalyzer for proactive TOA counting from .tim_file_analyzer import TimFileAnalyzer @@ -142,6 +143,11 @@ class Policy: nobs_threshold: int = 10000 # add nobs parameter if TOA count exceeds this threshold nobs_safety_margin: float = 1.1 # multiplier for nobs parameter (e.g., 1.1 = 10% more than actual count) + # Logging / stderr capture + stderr_ring_max_lines: int = 20000 + stderr_log_file: Optional[str] = None + include_stderr_tail_in_errors: int = 200 # 0 disables tail inclusion + # -------------------------- Wire serialization helpers --------------------- # @@ -219,6 +225,8 @@ def _worker_stdio_main() -> None: "platform": platform.platform(), "has_libstempo": False, "tempo2_version": None, + "proto_version": "1.2", + "capabilities": {"get_kind": True, "dir": True, "setitem": True, "get_slice": True, "path_access": True}, } } try: @@ -331,24 +339,149 @@ def _write_response(resp: Dict[str, Any]) -> None: if method == "get": name = params["name"] - val = getattr(obj, name) - # copy numpy views to decouple from lib memory + # Support dotted path and mapping access for parameters (e.g., 'RAJ.val') + parts = str(name).split(".") if isinstance(name, str) else [name] + cur = obj + missing = False + for idx, part in enumerate(parts): + # First hop supports attribute or mapping access + if idx == 0: + try: + if hasattr(cur, part): + cur = getattr(cur, part) + elif isinstance(cur, Mapping) or hasattr(cur, "__getitem__"): + cur = cur[part] + else: + raise AttributeError + except Exception: + missing = True + break + else: + try: + cur = getattr(cur, part) + except Exception: + missing = True + break + + if missing: + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "missing", "value": None})} + ) + continue + + # cur is the resolved object/value + if callable(cur): + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "callable", "value": None})} + ) + continue + try: - import numpy as _np2 # local alias + import numpy as _np2 - if hasattr(val, "base") and isinstance(val, _np2.ndarray): - val = val.copy() + if isinstance(cur, _np2.ndarray): + cur = cur.copy(order="C") + elif isinstance(cur, _np2.generic): + cur = cur.item() except Exception: pass - _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(val)}) + + _write_response( + {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "value", "value": cur})} + ) continue if method == "set": name, value = params["name"], params["value"] - setattr(obj, name, value) + # Support dotted path and mapping access for parameters (e.g., 'RAJ.val') + parts = str(name).split(".") if isinstance(name, str) else [name] + cur = obj + missing = False + # Traverse to parent of target + for idx, part in enumerate(parts[:-1]): + try: + if idx == 0: + if hasattr(cur, part): + cur = getattr(cur, part) + elif isinstance(cur, Mapping) or hasattr(cur, "__getitem__"): + cur = cur[part] + else: + raise AttributeError + 
else: + cur = getattr(cur, part) + except Exception: + missing = True + break + if missing: + _write_response( + { + "jsonrpc": "2.0", + "id": rid, + "error": { + "code": -32000, + "message": f"AttributeError: cannot resolve path for set: {name}", + "data": "", + }, + } + ) + continue + target = parts[-1] + try: + setattr(cur, target, value) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) continue + if method == "setitem": + # Set slice(s) on numpy array attributes like stoas, toaerrs + name = params["name"] + index = params["index"] + value = params["value"] + try: + arr = getattr(obj, name) + try: + import numpy as _np2 + + if isinstance(value, _np2.ndarray) and not _np2.can_cast( + value.dtype, arr.dtype, casting="safe" + ): + value = value.astype(arr.dtype, copy=False) + except Exception: + pass + arr[index] = value + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue + + if method == "get_slice": + name = params["name"] + index = params["index"] + try: + arr = getattr(obj, name) + import numpy as _np2 + + out = arr[index] + if isinstance(out, _np2.ndarray): + out = out.copy(order="C") + elif isinstance(out, _np2.generic): + out = out.item() + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)}) + except Exception: + et, ev, tb = _format_exc_tuple() + _write_response( + {"jsonrpc": "2.0", "id": rid, "error": {"code": -32000, "message": f"{et}: {ev}", "data": tb}} + ) + continue + if method == "call": name = params["name"] args = tuple(params.get("args", ())) @@ -358,8 +491,11 @@ def _write_response(resp: Dict[str, Any]) -> None: try: import numpy as _np2 - if hasattr(out, "base") and isinstance(out, _np2.ndarray): - out = out.copy() + if isinstance(out, _np2.ndarray): + # Always copy numpy arrays to avoid C++ object references + out = out.copy(order="C") + elif isinstance(out, _np2.generic): + out = out.item() except Exception: pass _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(out)}) @@ -374,6 +510,15 @@ def _write_response(resp: Dict[str, Any]) -> None: _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(None)}) continue + if method == "dir": + names = [] + for n in dir(obj): + if not n.startswith("_"): + names.append(n) + names.sort() + _write_response({"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py(names)}) + continue + _write_response( { "jsonrpc": "2.0", @@ -422,6 +567,12 @@ def _start(self, require_x86_64: bool = False): logger.debug(f"Launching subprocess with environment: PYTHONUNBUFFERED={env.get('PYTHONUNBUFFERED')}") logger.debug(f"Subprocess working directory: {os.getcwd()}") + creationflags = 0 + if os.name == "nt": + try: + creationflags = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore[attr-defined] + except Exception: + creationflags = 0 self.proc = subprocess.Popen( self.cmd, stdin=subprocess.PIPE, @@ -430,6 +581,10 @@ def _start(self, require_x86_64: bool = False): text=True, bufsize=1, # line buffered cwd=os.getcwd(), # Explicitly set working directory + env=env, + close_fds=True, + start_new_session=True, + creationflags=creationflags, ) 
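+        # start_new_session=True places the worker in its own process group,
+        # so _hard_kill() can signal the whole group without touching the parent.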
logger.debug(f"Worker process started with PID: {self.proc.pid}") @@ -438,21 +593,36 @@ def _start(self, require_x86_64: bool = False): import threading import collections - self._log_buf = collections.deque(maxlen=20000) + self._log_buf = collections.deque(maxlen=self.policy.stderr_ring_max_lines) + log_file = None + if self.policy.stderr_log_file: + try: + log_file = open(self.policy.stderr_log_file, "a", buffering=1, encoding="utf-8") + except Exception: + log_file = None - def _drain_stderr(pipe, sink_deque): + def _drain_stderr(pipe, sink_deque, sink_file): try: for line in iter(pipe.readline, ""): line = line.rstrip("\n") sink_deque.append(line) + if sink_file: + try: + sink_file.write(line + "\n") + except Exception: + pass logger.debug("[tempo2-stderr] %s", line) finally: with contextlib.suppress(Exception): pipe.close() + if sink_file: + with contextlib.suppress(Exception): + sink_file.flush() + sink_file.close() if self.proc.stderr is not None: self._stderr_thread = threading.Thread( - target=_drain_stderr, args=(self.proc.stderr, self._log_buf), daemon=True + target=_drain_stderr, args=(self.proc.stderr, self._log_buf, log_file), daemon=True ) self._stderr_thread.start() @@ -478,6 +648,11 @@ def _drain_stderr(pipe, sink_deque): info = hello_obj.get("hello", {}) logger.info(f"Worker hello received: {info}") + self._proto_version = info.get("proto_version", "1.0") + caps = info.get("capabilities") or {} + self._cap_get_kind = bool(caps.get("get_kind")) + self._cap_dir = bool(caps.get("dir")) + self._cap_get_slice = bool(caps.get("get_slice")) if require_x86_64: if str(info.get("machine", "")).lower() != "x86_64": @@ -526,19 +701,30 @@ def _readline_with_timeout(self, timeout: Optional[float]) -> Optional[str]: def _hard_kill(self): if self.proc and self.proc.poll() is None: - logger.warning(f"Hard killing worker process (PID: {self.proc.pid})") + logger.debug(f"Hard killing worker process (PID: {self.proc.pid})") try: - self.proc.terminate() + if os.name == "nt": + with contextlib.suppress(Exception): + self.proc.terminate() + else: + os.killpg(self.proc.pid, signal.SIGTERM) except Exception as e: - logger.warning(f"Failed to terminate process: {e}") - pass + logger.warning(f"Failed to terminate process group: {e}; falling back to terminate()") + with contextlib.suppress(Exception): + self.proc.terminate() t0 = time.time() while self.proc.poll() is None and (time.time() - t0) < self.policy.kill_grace_s: time.sleep(0.01) if self.proc.poll() is None: - logger.warning(f"Sending SIGKILL to worker process (PID: {self.proc.pid})") + logger.debug(f"Sending SIGKILL to worker process (PID: {self.proc.pid})") with contextlib.suppress(Exception): - os.kill(self.proc.pid, signal.SIGKILL) + try: + if os.name == "nt": + self.proc.kill() + else: + os.killpg(self.proc.pid, signal.SIGKILL) + except Exception: + os.kill(self.proc.pid, signal.SIGKILL) self.proc = None def close(self): @@ -567,7 +753,9 @@ def _send_rpc(self, method: str, params: Dict[str, Any], timeout: Optional[float self._id += 1 rid = self._id - logger.debug(f"Sending RPC {method} (id: {rid})") + # Only log debug for non-get methods to reduce noise + if method != "get": + logger.debug(f"Sending RPC {method} (id: {rid})") frame = { "jsonrpc": "2.0", @@ -577,9 +765,15 @@ def _send_rpc(self, method: str, params: Dict[str, Any], timeout: Optional[float } line = json.dumps(frame) + "\n" + # Protect frames from interleaving + import threading + + if not hasattr(self, "_rpc_lock"): + self._rpc_lock = threading.Lock() try: - 
self.proc.stdin.write(line)
-            self.proc.stdin.flush()
+            with self._rpc_lock:
+                self.proc.stdin.write(line)
+                self.proc.stdin.flush()
         except Exception as e:
             logger.error(f"Failed to send RPC {method}: {e}")
             self._hard_kill()
@@ -619,9 +813,23 @@ def _send_rpc(self, method: str, params: Dict[str, Any], timeout: Optional[float
             msg = err.get("message", "error")
             data = err.get("data", "")
             logger.error(f"RPC {method} failed: {msg}")
-            raise Tempo2Error(f"{msg}\n{data}")
+            tail = ""
+            if getattr(self, "_log_buf", None) and (self.policy.include_stderr_tail_in_errors or 0) > 0:
+                try:
+                    tail_lines = list(self._log_buf)[-self.policy.include_stderr_tail_in_errors :]
+                    if tail_lines:
+                        blob = "\n".join(tail_lines)
+                        max_bytes = 16384
+                        if len(blob) > max_bytes:
+                            blob = blob[-max_bytes:]
+                        tail = "\n--- tempo2 stderr (tail) ---\n" + blob
+                except Exception:
+                    tail = ""
+            raise Tempo2Error(f"{msg}\n{data}{tail}")
 
-        logger.debug(f"RPC {method} completed successfully")
+        # Only log debug for non-get methods to reduce noise
+        if method != "get":
+            logger.debug(f"RPC {method} completed successfully")
         result_b64 = resp.get("result_b64", None)
         return _b64_loads_py(result_b64) if result_b64 is not None else None
@@ -635,6 +843,21 @@ def get(self, name: str):
         logger.debug(f"Getting attribute: {name}")
         return self._send_rpc("get", {"name": name})
 
+    def get_kind(self, name: str):
+        """Return (kind, value) where kind in {"value","callable","missing"}.
+        For legacy workers without get-kind support, assumes raw value => ("value", value).
+        """
+        resp = self._send_rpc("get", {"name": name})
+        if isinstance(resp, dict) and "kind" in resp:
+            return (resp.get("kind"), resp.get("value"))
+        return ("value", resp)
+
+    def dir(self):
+        """Return list of public attribute names if worker supports dir RPC; else empty list."""
+        if getattr(self, "_cap_dir", False):
+            return self._send_rpc("dir", {})
+        return []
+
     def set(self, name: str, value: Any):
         logger.debug(f"Setting attribute: {name}")
         return self._send_rpc("set", {"name": name, "value": value})
@@ -643,6 +866,14 @@ def call(self, name: str, args=(), kwargs=None):
         logger.debug(f"Calling method: {name} with args={args}, kwargs={kwargs}")
         return self._send_rpc("call", {"name": name, "args": tuple(args), "kwargs": dict(kwargs or {})})
 
+    def setitem(self, name: str, index, value: Any):
+        logger.debug(f"Setting array slice: {name}[{index}]")
+        return self._send_rpc("setitem", {"name": name, "index": index, "value": value})
+
+    def get_slice(self, name: str, index):
+        logger.debug(f"Getting array slice: {name}[{index}]")
+        return self._send_rpc("get_slice", {"name": name, "index": index})
+
     def rss(self) -> Optional[int]:
         try:
             logger.debug("Getting worker RSS memory usage")
@@ -874,6 +1105,26 @@ def __init__(self, env_name: Optional[str] = None, **kwargs):
     def _construct_with_retries(self):
         logger.info(f"Starting construction with {self._policy.ctor_retry + 1} total attempts")
 
+        # Fast-fail on missing input files to avoid noisy retries
+        try:
+            parfile = self._ctor_kwargs.get("parfile")
+            timfile = self._ctor_kwargs.get("timfile")
+            cwd = os.getcwd()
+            if parfile and not Path(parfile).exists():
+                raise Tempo2ConstructorFailed(
+                    f"parfile not found: {parfile} (cwd: {cwd}). Provide an absolute path or correct relative path."
+                )
+            if timfile and not Path(timfile).exists():
+                raise Tempo2ConstructorFailed(
+                    f"timfile not found: {timfile} (cwd: {cwd}). Provide an absolute path or correct relative path."
+ ) + except Tempo2ConstructorFailed: + # Re-raise to surface a clean single error without retries + raise + except Exception: + # Do not block construction for unexpected preflight errors + logger.debug("Preflight path check skipped due to unexpected error:", exc_info=True) + # Proactive TOA counting to avoid "Too many TOAs" errors if self._policy.auto_nobs_retry: self._proactive_nobs_setup() @@ -898,6 +1149,14 @@ def _construct_with_retries(self): except Exception as e: logger.warning(f"Construction attempt {attempt + 1} failed: {e}") last_exc = e + # If it's a file-not-found style error, fail fast without retries + msg = str(e) + if any( + t in msg + for t in ("Cannot find parfile", "Cannot find timfile", "parfile not found", "timfile not found") + ): + logger.error("Input file missing; not retrying constructor.") + break # kill and retry try: if self._wp: @@ -1006,14 +1265,18 @@ def _rpc(self, call: str, **payload): out = self._wp.set(payload["name"], payload["value"]) elif call == "call": out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) + elif call == "setitem": + out = self._wp.setitem(payload["name"], payload.get("index"), payload.get("value")) else: raise Tempo2ProtocolError(f"unknown call {call}") self._state.calls_ok += 1 logger.debug(f"RPC {call} successful, total calls: {self._state.calls_ok}") return out except (Tempo2Timeout, Tempo2Crashed, Tempo2ProtocolError, Tempo2Error) as e: - logger.warning(f"RPC {call} failed with {type(e).__name__}: {e}") - logger.info("Attempting automatic worker recycle and retry") + # Only log warnings for actual failures, not expected attribute discovery failures + if call != "get" or not str(e).startswith("AttributeError"): + logger.warning(f"RPC {call} failed with {type(e).__name__}: {e}") + logger.info("Attempting automatic worker recycle and retry") # automatic one-time recycle on a fresh worker self._recycle() assert self._wp is not None @@ -1024,7 +1287,8 @@ def _rpc(self, call: str, **payload): else: out = self._wp.call(payload["name"], payload.get("args", ()), payload.get("kwargs", {})) self._state.calls_ok += 1 - logger.info(f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}") + if call != "get" or not str(e).startswith("AttributeError"): + logger.info(f"RPC {call} succeeded after recycle, total calls: {self._state.calls_ok}") return out # ------------------------ Attribute proxying magic ------------------------ # @@ -1048,14 +1312,28 @@ def __getattr__(self, name: str): def _remote_method(*args, **kwargs): return self._rpc("call", name=name, args=args, kwargs=kwargs) - # Try a GET first; if it errors, assume it's a method + # Non-exceptional discovery using get-kind if available + try: + if self._wp and getattr(self._wp, "_cap_get_kind", False): + kind, payload = self._wp.get_kind(name) + if kind == "value": + return payload + if kind == "callable": + return _remote_method + if kind == "missing": + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{name}'") + except Exception: + pass + + # Legacy fallback: try get and assume callable on failure try: val = self._rpc("get", name=name) - except Tempo2Error: + if not callable(val): + return val return _remote_method - if callable(val): + except (Tempo2Error, Tempo2Timeout, Tempo2Crashed, Tempo2ProtocolError): return _remote_method - return val def __setattr__(self, name: str, value: Any): if name in 
tempopulsar.__slots__:
@@ -1063,6 +1341,39 @@ def __setattr__(self, name: str, value: Any):
         _ = self._rpc("set", name=name, value=value)
         return None
 
+    def __dir__(self):
+        """Return a list of available attributes for dir() function."""
+        try:
+            if self._wp and getattr(self._wp, "_cap_dir", False):
+                return list(self._wp.dir())
+        except Exception:
+            pass
+
+        # Fallback: return a basic set of common tempopulsar attributes
+        return [
+            "name",
+            "nobs",
+            "stoas",
+            "toaerrs",
+            "freqs",
+            "ndim",
+            "residuals",
+            "designmatrix",
+            "toas",
+            "fit",
+            "vals",
+            "errs",
+            "pars",
+            "flags",
+            "flagvals",
+            "savepar",
+            "savetim",
+            "chisq",
+            "rms",
+            "ssbfreqs",
+            "logs",
+        ]
+
     # Explicit helpers for common call shapes
     def residuals(self, **kwargs):
         return self._rpc("call", name="residuals", kwargs=kwargs)
@@ -1079,6 +1390,162 @@ def fit(self, **kwargs):
     def logs(self, tail: int = 500) -> str:
         return self._wp.logs(tail) if self._wp else ""
 
+    # Mapping-style access to parameters, proxied to libstempo
+    def __getitem__(self, key: str):
+        return _ParamProxy(self, key)
+
+    def __setitem__(self, key: str, value: Any):
+        raise TypeError("Direct assignment to parameters is not supported; set fields like psr['RAJ'].val = x")
+
+    # Expose array-like attributes as write-through proxies
+    @property
+    def stoas(self):
+        return _ArrayProxy(self, "stoas")
+
+    @property
+    def toaerrs(self):
+        return _ArrayProxy(self, "toaerrs")
+
+    @property
+    def freqs(self):
+        return _ArrayProxy(self, "freqs")
+
+
+class _ParamProxy:
+    __slots__ = ("_parent", "_name")
+
+    def __init__(self, parent: tempopulsar, name: str) -> None:
+        object.__setattr__(self, "_parent", parent)
+        object.__setattr__(self, "_name", name)
+
+    def __getattr__(self, attr: str):
+        # Fetch field via dotted get path (e.g., RAJ.val), honoring get-kind
+        try:
+            if self._parent._wp and getattr(self._parent._wp, "_cap_get_kind", False):
+                kind, payload = self._parent._wp.get_kind(f"{self._name}.{attr}")
+                if kind == "value":
+                    return payload
+                if kind == "callable":
+                    # Expose a callable that routes via call with dotted name
+                    def _remote_method(*args, **kwargs):
+                        return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs)
+
+                    return _remote_method
+                if kind == "missing":
+                    raise AttributeError(f"'{self._name}' has no attribute '{attr}'")
+        except Exception:
+            pass
+        # Legacy fallback: direct get
+        resp = self._parent._rpc("get", name=f"{self._name}.{attr}")
+        if isinstance(resp, dict) and "kind" in resp:
+            if resp["kind"] == "value":
+                return resp.get("value")
+            if resp["kind"] == "callable":
+
+                def _remote_method(*args, **kwargs):
+                    return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs)
+
+                return _remote_method
+            raise AttributeError(f"'{self._name}' has no attribute '{attr}'")
+        return resp
+
+    def __setattr__(self, attr: str, value: Any) -> None:
+        if attr in _ParamProxy.__slots__:
+            return object.__setattr__(self, attr, value)
+        _ = self._parent._rpc("set", name=f"{self._name}.{attr}", value=value)
+        return None
+
+    def __repr__(self) -> str:
+        try:
+            v = self.__getattr__("val")
+            e = self.__getattr__("err")
+            return f"<_ParamProxy {self._name}: val={v}, err={e}>"
+        except Exception:
+            return f"<_ParamProxy {self._name}>"
+
+
+class _ArrayProxy:
+    __slots__ = ("_parent", "_name")
+
+    def __init__(self, parent: tempopulsar, name: str) -> None:
+        object.__setattr__(self, "_parent", parent)
+        object.__setattr__(self, "_name", name)
+
+    # numpy reads
+    def __array__(self, dtype=None):
+        import numpy as _np
+
+        arr = None
+        try:
+            # Prefer get-kind if 
supported to avoid unnecessary data shapes + if self._parent._wp and getattr(self._parent._wp, "_cap_get_kind", False): + kind, payload = self._parent._wp.get_kind(self._name) + if kind == "value": + arr = payload + elif kind == "callable": + # Arrays are not callable; treat as empty + arr = [] + else: + arr = [] + else: + resp = self._parent._rpc("get", name=self._name) + if isinstance(resp, dict) and "kind" in resp: + arr = resp.get("value") + else: + arr = resp + except Exception: + resp = self._parent._rpc("get", name=self._name) + if isinstance(resp, dict) and "kind" in resp: + arr = resp.get("value") + else: + arr = resp + a = _np.asarray(arr) + if dtype is not None: + a = a.astype(dtype, copy=False) + return a + + def __len__(self): + a = self.__array__() + return int(a.shape[0]) if getattr(a, "ndim", 1) > 0 else 1 + + @property + def shape(self): + return self.__array__().shape + + @property + def dtype(self): + return self.__array__().dtype + + # python indexing + def __getitem__(self, idx): + try: + if self._parent._wp and getattr(self._parent._wp, "_cap_get_slice", False): + return self._parent._wp.get_slice(self._name, idx) + except Exception: + pass + return self.__array__()[idx] + + def __setitem__(self, idx, value): + _ = self._parent._rpc("setitem", name=self._name, index=idx, value=value) + return None + + def __repr__(self) -> str: + try: + return repr(self.__array__()) + except Exception: + return f"<_ArrayProxy {self._name}>" + + def __str__(self) -> str: + try: + return str(self.__array__()) + except Exception: + return self.__repr__() + + # Delegate unknown attributes/methods to the numpy array + def __getattr__(self, name: str): + arr = self.__array__() + return getattr(arr, name) + def __del__(self): with contextlib.suppress(Exception): if self._wp is not None: diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 9cd33e1..8a0dd2e 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -116,6 +116,69 @@ def test_sandbox_native_parity(self): self.assertEqual(par_name in s_fit, par_name in n_fit) self.assertEqual(par_name in s_set, par_name in n_set) + def test_param_proxy_accessors(self): + """Test psr['parname'].val/err/fit/set mapping accessors and roundtrips.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Choose parameters that are present and safe to touch + par_val = "RAJ" + par_fit = "DM" # commonly present and safe to toggle fit flag + + # Read val/err via mapping + v0 = psr[par_val].val + e0 = psr[par_val].err + self.assertIsInstance(float(v0), float) + self.assertIsInstance(float(e0), float) + + # Roundtrip val by setting the same value (as Python float) + psr[par_val].val = float(v0) + self.assertAlmostEqual(float(psr[par_val].val), float(v0), places=12) + + # Roundtrip err by setting the same value (as Python float) + psr[par_val].err = float(e0) + self.assertAlmostEqual(float(psr[par_val].err), float(e0), places=12) + + # Toggle fit flag and revert + fit0 = bool(psr[par_fit].fit) + psr[par_fit].fit = not fit0 + self.assertEqual(bool(psr[par_fit].fit), (not fit0)) + # revert + psr[par_fit].fit = fit0 + self.assertEqual(bool(psr[par_fit].fit), fit0) + + # 'set' flag should be boolean and readable; do not change it here + self.assertIsInstance(bool(psr[par_val].set), bool) + + def test_stoas_edit_and_fit_matches_native(self): + """Edit stoas and toaerrs, run fit, and compare residuals to native.""" + rng = np.random.default_rng(12345) + + # Sandbox and native + psr_s = tempopulsar(parfile=self.parfile, 
timfile=self.timfile) + psr_n = t2.tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Create identical noise realization + noise = 0.1e-6 * rng.standard_normal(psr_s.nobs) / 86400.0 + + # Apply to stoas and toaerrs in both + # Sandbox: use write-through proxies (backward compatible API) + psr_s.stoas[:] = psr_s.stoas[:] + noise + psr_s.toaerrs[:] = 0.1 + + # Native + psr_n.stoas[:] = psr_n.stoas + noise + psr_n.toaerrs[:] = 0.1 + + # Fit both + _ = psr_s.fit() + _ = psr_n.fit() + + # Compare residuals tightly + res_s = psr_s.residuals() + res_n = psr_n.residuals() + self.assertEqual(res_s.shape, res_n.shape) + assert_allclose(res_s, res_n, rtol=0, atol=0) + class TestTimFileAnalyzer(unittest.TestCase): @classmethod From 9cac16d902ff4991d967ec7c443093e0e98b851b Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sun, 12 Oct 2025 21:20:00 +0200 Subject: [PATCH 11/16] Added output of stdout/stderr to screen in real time --- libstempo/sandbox.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 9b7b7f3..c14233d 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -589,7 +589,7 @@ def _start(self, require_x86_64: bool = False): logger.debug(f"Worker process started with PID: {self.proc.pid}") - # Start background stderr drain to avoid backpressure and capture logs + # Start background stderr drain to capture logs AND output to real-time stderr import threading import collections @@ -606,6 +606,10 @@ def _drain_stderr(pipe, sink_deque, sink_file): for line in iter(pipe.readline, ""): line = line.rstrip("\n") sink_deque.append(line) + + # Write to real stderr for real-time output (native-like behavior) + print(line, file=sys.stderr, flush=True) + if sink_file: try: sink_file.write(line + "\n") From e16b10d7dc49e41d4ce32e76c3d0b98d21f7626e Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sun, 12 Oct 2025 21:56:27 +0200 Subject: [PATCH 12/16] removed: old legacy RPC path --- libstempo/sandbox.py | 119 ++++++++++++------------------------------- 1 file changed, 32 insertions(+), 87 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index c14233d..8360eb9 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -653,10 +653,6 @@ def _drain_stderr(pipe, sink_deque, sink_file): info = hello_obj.get("hello", {}) logger.info(f"Worker hello received: {info}") self._proto_version = info.get("proto_version", "1.0") - caps = info.get("capabilities") or {} - self._cap_get_kind = bool(caps.get("get_kind")) - self._cap_dir = bool(caps.get("dir")) - self._cap_get_slice = bool(caps.get("get_slice")) if require_x86_64: if str(info.get("machine", "")).lower() != "x86_64": @@ -857,10 +853,8 @@ def get_kind(self, name: str): return ("value", resp) def dir(self): - """Return list of public attribute names if worker supports dir RPC; else empty list.""" - if getattr(self, "_cap_dir", False): - return self._send_rpc("dir", {}) - return [] + """Return list of public attribute names.""" + return self._send_rpc("dir", {}) def set(self, name: str, value: Any): logger.debug(f"Setting attribute: {name}") @@ -1316,28 +1310,15 @@ def __getattr__(self, name: str): def _remote_method(*args, **kwargs): return self._rpc("call", name=name, args=args, kwargs=kwargs) - # Non-exceptional discovery using get-kind if available - try: - if self._wp and getattr(self._wp, "_cap_get_kind", False): - kind, payload = self._wp.get_kind(name) - if kind == "value": - return payload - if kind == "callable": - return 
_remote_method - if kind == "missing": - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{name}'") - except Exception: - pass - - # Legacy fallback: try get and assume callable on failure - try: - val = self._rpc("get", name=name) - if not callable(val): - return val - return _remote_method - except (Tempo2Error, Tempo2Timeout, Tempo2Crashed, Tempo2ProtocolError): + # Non-exceptional discovery using get-kind + kind, payload = self._wp.get_kind(name) + if kind == "value": + return payload + if kind == "callable": return _remote_method + if kind == "missing": + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{name}'") def __setattr__(self, name: str, value: Any): if name in tempopulsar.__slots__: @@ -1348,8 +1329,7 @@ def __setattr__(self, name: str, value: Any): def __dir__(self): """Return a list of available attributes for dir() function.""" try: - if self._wp and getattr(self._wp, "_cap_dir", False): - return list(self._wp.dir()) + return list(self._wp.dir()) except Exception: pass @@ -1424,34 +1404,18 @@ def __init__(self, parent: tempopulsar, name: str) -> None: def __getattr__(self, attr: str): # Fetch field via dotted get path (e.g., RAJ.val), honoring get-kind - try: - if self._parent._wp and getattr(self._parent._wp, "_cap_get_kind", False): - kind, payload = self._parent._wp.get_kind(f"{self._name}.{attr}") - if kind == "value": - return payload - if kind == "callable": - # Expose a callable that routes via call with dotted name - def _remote_method(*args, **kwargs): - return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs) - - return _remote_method - if kind == "missing": - raise AttributeError(f"'{self._name}' has no attribute '{attr}'") - except Exception: - pass - # Legacy fallback: direct get - resp = self._parent._rpc("get", name=f"{self._name}.{attr}") - if isinstance(resp, dict) and "kind" in resp: - if resp["kind"] == "value": - return resp.get("value") - if resp["kind"] == "callable": - - def _remote_method(*args, **kwargs): - return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs) + kind, payload = self._parent._wp.get_kind(f"{self._name}.{attr}") + if kind == "value": + return payload + if kind == "callable": + # Expose a callable that routes via call with dotted name + def _remote_method(*args, **kwargs): + return self._parent._rpc("call", name=f"{self._name}.{attr}", args=args, kwargs=kwargs) - return _remote_method + return _remote_method + if kind == "missing": raise AttributeError(f"'{self._name}' has no attribute '{attr}'") - return resp + raise Tempo2ProtocolError(f"unexpected get-kind '{kind}' for attribute '{self._name}.{attr}'") def __setattr__(self, attr: str, value: Any) -> None: if attr in _ParamProxy.__slots__: @@ -1479,30 +1443,16 @@ def __init__(self, parent: tempopulsar, name: str) -> None: def __array__(self, dtype=None): import numpy as _np - arr = None - try: - # Prefer get-kind if supported to avoid unnecessary data shapes - if self._parent._wp and getattr(self._parent._wp, "_cap_get_kind", False): - kind, payload = self._parent._wp.get_kind(self._name) - if kind == "value": - arr = payload - elif kind == "callable": - # Arrays are not callable; treat as empty - arr = [] - else: - arr = [] - else: - resp = self._parent._rpc("get", name=self._name) - if 
isinstance(resp, dict) and "kind" in resp: - arr = resp.get("value") - else: - arr = resp - except Exception: - resp = self._parent._rpc("get", name=self._name) - if isinstance(resp, dict) and "kind" in resp: - arr = resp.get("value") - else: - arr = resp + # Use get-kind to get the array data + kind, payload = self._parent._wp.get_kind(self._name) + if kind == "value": + arr = payload + elif kind == "callable": + # Arrays are not callable; treat as empty + arr = [] + else: + arr = [] + a = _np.asarray(arr) if dtype is not None: a = a.astype(dtype, copy=False) @@ -1522,12 +1472,7 @@ def dtype(self): # python indexing def __getitem__(self, idx): - try: - if self._parent._wp and getattr(self._parent._wp, "_cap_get_slice", False): - return self._parent._wp.get_slice(self._name, idx) - except Exception: - pass - return self.__array__()[idx] + return self._parent._wp.get_slice(self._name, idx) def __setitem__(self, idx, value): _ = self._parent._rpc("setitem", name=self._name, index=idx, value=value) From c1e675540cbc144280e30cb67c21d1fa048d7c5a Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Sun, 12 Oct 2025 22:16:00 +0200 Subject: [PATCH 13/16] fix: sandbox did not have state management. Now it restores state after recycle --- libstempo/sandbox.py | 108 +++++++++++++++++++++++- tests/test_sandbox.py | 192 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 4 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 8360eb9..29dacb5 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -606,10 +606,10 @@ def _drain_stderr(pipe, sink_deque, sink_file): for line in iter(pipe.readline, ""): line = line.rstrip("\n") sink_deque.append(line) - + # Write to real stderr for real-time output (native-like behavior) print(line, file=sys.stderr, flush=True) - + if sink_file: try: sink_file.write(line + "\n") @@ -1048,6 +1048,16 @@ class _State: created_at: float calls_ok: int + # State cache for crash recovery + param_cache: Dict[str, Dict[str, Any]] = dataclasses.field( + default_factory=dict + ) # {'RAJ': {'val': 5.016, 'fit': True}, ...} + array_cache: Dict[str, Any] = dataclasses.field( + default_factory=dict + ) # {'stoas': modified_array, 'toaerrs': modified_array} + # Crash recovery statistics + crash_count: int = 0 + last_crash_at: Optional[float] = None class tempopulsar: @@ -1143,10 +1153,18 @@ def _construct_with_retries(self): self._state.created_at = time.time() self._state.calls_ok = 0 logger.info(f"Construction successful on attempt {attempt + 1}") + + # Restore state after successful reconstruction + self._restore_state_after_reconstruction() return except Exception as e: logger.warning(f"Construction attempt {attempt + 1} failed: {e}") last_exc = e + + # Record crash if this is a retry (not the first attempt) + if attempt > 0: + self._record_crash() + # If it's a file-not-found style error, fail fast without retries msg = str(e) if any( @@ -1204,6 +1222,62 @@ def _proactive_nobs_setup(self): logger.warning(f"Proactive nobs setup failed: {e}") # Don't raise - this is just optimization, construction should still work + # ----------------------------- state management ----------------------------- # + + def _capture_param_state(self, param_name: str, field: str, value: Any) -> None: + """Capture parameter state for crash recovery.""" + if param_name not in self._state.param_cache: + self._state.param_cache[param_name] = {} + self._state.param_cache[param_name][field] = value + logger.debug(f"Captured param state: 
{param_name}.{field} = {value}") + + def _capture_array_state(self, array_name: str, value: Any) -> None: + """Capture array state for crash recovery.""" + self._state.array_cache[array_name] = value + logger.debug(f"Captured array state: {array_name}") + + def _restore_state_after_reconstruction(self) -> None: + """Restore parameter values, fit flags, and array modifications after worker reconstruction.""" + if not self._state.param_cache and not self._state.array_cache: + logger.debug("No state to restore") + return + + logger.info(f"Restoring state: {len(self._state.param_cache)} params, {len(self._state.array_cache)} arrays") + + # Restore parameter values and fit flags + for param_name, param_state in self._state.param_cache.items(): + for field, value in param_state.items(): + try: + self._wp.set(f"{param_name}.{field}", value) + logger.debug(f"Restored param: {param_name}.{field} = {value}") + except Exception as e: + logger.warning(f"Failed to restore param {param_name}.{field}: {e}") + + # Restore array modifications + for array_name, array_data in self._state.array_cache.items(): + try: + self._wp.setitem(array_name, slice(None), array_data) + logger.debug(f"Restored array: {array_name}") + except Exception as e: + logger.warning(f"Failed to restore array {array_name}: {e}") + + logger.info("State restoration completed") + + def _record_crash(self) -> None: + """Record crash statistics.""" + self._state.crash_count += 1 + self._state.last_crash_at = time.time() + logger.info(f"Worker crash recorded (total crashes: {self._state.crash_count})") + + def get_crash_stats(self) -> Dict[str, Any]: + """Get crash recovery statistics.""" + return { + "crash_count": self._state.crash_count, + "last_crash_at": self._state.last_crash_at, + "worker_age_s": time.time() - self._state.created_at if self._wp else None, + "calls_since_creation": self._state.calls_ok, + } + # ----------------------------- recycling policy --------------------------- # def _should_recycle(self) -> bool: @@ -1275,7 +1349,8 @@ def _rpc(self, call: str, **payload): if call != "get" or not str(e).startswith("AttributeError"): logger.warning(f"RPC {call} failed with {type(e).__name__}: {e}") logger.info("Attempting automatic worker recycle and retry") - # automatic one-time recycle on a fresh worker + # Record crash and recycle + self._record_crash() self._recycle() assert self._wp is not None if call == "get": @@ -1323,6 +1398,10 @@ def _remote_method(*args, **kwargs): def __setattr__(self, name: str, value: Any): if name in tempopulsar.__slots__: return object.__setattr__(self, name, value) + + # Capture state for crash recovery + self._capture_array_state(name, value) + _ = self._rpc("set", name=name, value=value) return None @@ -1420,6 +1499,10 @@ def _remote_method(*args, **kwargs): def __setattr__(self, attr: str, value: Any) -> None: if attr in _ParamProxy.__slots__: return object.__setattr__(self, attr, value) + + # Capture parameter state for crash recovery + self._parent._capture_param_state(self._name, attr, value) + _ = self._parent._rpc("set", name=f"{self._name}.{attr}", value=value) return None @@ -1452,7 +1535,7 @@ def __array__(self, dtype=None): arr = [] else: arr = [] - + a = _np.asarray(arr) if dtype is not None: a = a.astype(dtype, copy=False) @@ -1475,6 +1558,23 @@ def __getitem__(self, idx): return self._parent._wp.get_slice(self._name, idx) def __setitem__(self, idx, value): + # Capture array state for crash recovery + # For full array replacement (slice(None)), capture the entire array + if idx 
== slice(None): + self._parent._capture_array_state(self._name, value) + else: + # For partial updates, we need to get the current array and apply the change + # This is more complex, so for now we'll capture the full array after the change + try: + current_array = self.__array__() + if hasattr(current_array, "copy"): + new_array = current_array.copy() + new_array[idx] = value + self._parent._capture_array_state(self._name, new_array) + except Exception: + # If we can't capture the state, continue anyway + pass + _ = self._parent._rpc("setitem", name=self._name, index=idx, value=value) return None diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 8a0dd2e..7746cb8 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -180,6 +180,198 @@ def test_stoas_edit_and_fit_matches_native(self): assert_allclose(res_s, res_n, rtol=0, atol=0) +class TestStateManagement(unittest.TestCase): + """Tests for state management and crash recovery.""" + + @classmethod + def setUpClass(cls): + cls.data_path = t2.__path__[0] + "/data/" + cls.parfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.par" + cls.timfile = cls.data_path + "J1909-3744_NANOGrav_dfg+12.tim" + + def test_param_state_capture(self): + """Test that parameter modifications are captured in state cache.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify parameters + original_raj = psr["RAJ"].val + psr["RAJ"].val = original_raj + 0.001 + psr["DM"].fit = False + + # Check state cache + self.assertIn("RAJ", psr._state.param_cache) + self.assertIn("val", psr._state.param_cache["RAJ"]) + self.assertEqual(psr._state.param_cache["RAJ"]["val"], original_raj + 0.001) + self.assertEqual(psr._state.param_cache["DM"]["fit"], False) + + def test_array_state_capture(self): + """Test that array modifications are captured in state cache.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify arrays + original_stoas = psr.stoas.copy() + psr.stoas[:] = original_stoas + 1e-6 + + # Check state cache + self.assertIn("stoas", psr._state.array_cache) + expected_stoas = original_stoas + 1e-6 + np.testing.assert_allclose(psr._state.array_cache["stoas"], expected_stoas) + + def test_state_restoration_after_recycle(self): + """Test that state is restored after worker recycle.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify state + psr["RAJ"].val = 5.020 + psr["DM"].fit = True + original_stoas = psr.stoas.copy() + psr.stoas[:] = original_stoas + 2e-6 + + # Force recycle + psr._recycle() + + # Verify restoration + self.assertAlmostEqual(psr["RAJ"].val, 5.020, places=6) + self.assertTrue(psr["DM"].fit) + np.testing.assert_allclose(psr.stoas, original_stoas + 2e-6, rtol=1e-10) + + def test_crash_statistics_tracking(self): + """Test crash statistics tracking.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Initial stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 0) + self.assertIsNone(stats["last_crash_at"]) + + # Record crash manually + psr._record_crash() + + # Check stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + + # Record another crash + psr._record_crash() + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 2) + + def test_crash_stats_after_recycle(self): + """Test crash statistics after recycle.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Record crash and recycle + 
psr._record_crash() + psr._recycle() + + # Check stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + self.assertGreater(stats["worker_age_s"], 0) + + def test_complete_crash_recovery_workflow(self): + """Test complete workflow: modify -> crash -> recover -> verify.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Step 1: Modify state + original_raj = psr["RAJ"].val + original_dm_fit = psr["DM"].fit + original_stoas = psr.stoas.copy() + + psr["RAJ"].val = original_raj + 0.005 + psr["DM"].fit = not original_dm_fit + psr.stoas[:] = original_stoas + 5e-6 + + # Step 2: Simulate crash and recovery + psr._record_crash() + psr._recycle() + + # Step 3: Verify state restoration + self.assertAlmostEqual(psr["RAJ"].val, original_raj + 0.005, places=6) + self.assertEqual(psr["DM"].fit, not original_dm_fit) + np.testing.assert_allclose(psr.stoas, original_stoas + 5e-6, rtol=1e-10) + + # Step 4: Verify crash stats + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 1) + self.assertIsNotNone(stats["last_crash_at"]) + + def test_state_preservation_across_multiple_crashes(self): + """Test state preservation across multiple crashes.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Initial modifications + psr["RAJ"].val = 5.025 + psr["DM"].fit = False + + # Multiple crashes and recoveries + for i in range(3): + psr._record_crash() + psr._recycle() + + # Verify state preserved + self.assertAlmostEqual(psr["RAJ"].val, 5.025, places=6) + self.assertFalse(psr["DM"].fit) + + # Check final crash count + stats = psr.get_crash_stats() + self.assertEqual(stats["crash_count"], 3) + + def test_state_restoration_with_invalid_parameters(self): + """Test state restoration when some parameters are invalid.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Modify valid parameter + psr["RAJ"].val = 5.030 + + # Add invalid parameter to cache (simulate edge case) + psr._state.param_cache["INVALID_PARAM"] = {"val": 999.0} + + # Recycle and verify valid parameter restored + psr._recycle() + self.assertAlmostEqual(psr["RAJ"].val, 5.030, places=6) + + def test_empty_state_cache_restoration(self): + """Test restoration when state cache is empty.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Ensure empty cache + psr._state.param_cache.clear() + psr._state.array_cache.clear() + + # Recycle should not fail + psr._recycle() + + # Basic functionality should still work + self.assertEqual(psr.name, "1909-3744") + + def test_state_capture_performance(self): + """Test that state capture doesn't significantly impact performance.""" + import time + + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile) + + # Time parameter modifications + start = time.time() + for i in range(100): + psr["RAJ"].val = 5.0 + i * 0.001 + param_time = time.time() - start + + # Time array modifications + start = time.time() + for i in range(10): + # Get the array, modify it, and set it back + current_stoas = psr.stoas.copy() + psr.stoas[:] = current_stoas + i * 1e-6 + array_time = time.time() - start + + # Should be reasonable (adjust thresholds as needed) + self.assertLess(param_time, 5.0) # 100 param changes in < 5 seconds + self.assertLess(array_time, 10.0) # 10 array changes in < 10 seconds + + class TestTimFileAnalyzer(unittest.TestCase): @classmethod def setUpClass(cls): From 95a094fadcf32911fec93a21ba49ff85764f727d Mon Sep 17 00:00:00 2001 From: Rutger van 
Haasteren Date: Mon, 13 Oct 2025 20:18:37 +0000 Subject: [PATCH 14/16] fix: pickling error for unpickleable psr attributes, solved by adding safe serialization --- libstempo/sandbox.py | 73 ++++++++++++++++++++++++++++++++++++++++++- tests/test_sandbox.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 29dacb5..7786f55 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -386,8 +386,76 @@ def _write_response(resp: Dict[str, Any]) -> None: except Exception: pass + # Safely serialize value; some libstempo/Boost.Python objects are not picklable + try: + _ = _b64_dumps_py({"kind": "value", "value": cur}) + result_payload = {"kind": "value", "value": cur} + except Exception: + # Best-effort conversion for libstempo parameter-like objects + safe_value = None + try: + # Detect param-like structures (e.g., RAJ/DEC) and extract primitives + has_val = hasattr(cur, "val") or hasattr(cur, "_val") + has_err = hasattr(cur, "err") or hasattr(cur, "_err") + has_fit = hasattr(cur, "fit") or hasattr(cur, "fitFlag") or hasattr(cur, "_fitFlag") + if has_val or has_err or has_fit: + val = None + err = None + fit = None + with contextlib.suppress(Exception): + v = getattr(cur, "val", getattr(cur, "_val", None)) + # Convert numpy scalars to Python + try: + import numpy as _np2 + + if isinstance(v, _np2.generic): + v = v.item() + except Exception: + pass + val = v + with contextlib.suppress(Exception): + e = getattr(cur, "err", getattr(cur, "_err", None)) + try: + import numpy as _np2 + + if isinstance(e, _np2.generic): + e = e.item() + except Exception: + pass + err = e + with contextlib.suppress(Exception): + f = getattr(cur, "fit", None) + if f is None: + f = getattr(cur, "fitFlag", getattr(cur, "_fitFlag", None)) + # Normalize to bool when possible + if isinstance(f, (int, bool)): + fit = bool(f) + else: + fit = f + name_guess = None + with contextlib.suppress(Exception): + name_guess = getattr(cur, "name", None) or getattr(cur, "label", None) + safe_value = { + "__libstempo_param__": True, + "name": name_guess, + "val": val, + "err": err, + "fit": fit, + } + else: + # Fallback to repr string if completely opaque + safe_value = {"__repr__": repr(cur)} + except Exception: + safe_value = {"__repr__": repr(cur)} + + result_payload = {"kind": "value", "value": safe_value} + _write_response( - {"jsonrpc": "2.0", "id": rid, "result_b64": _b64_dumps_py({"kind": "value", "value": cur})} + { + "jsonrpc": "2.0", + "id": rid, + "result_b64": _b64_dumps_py(result_payload), + } ) continue @@ -1388,6 +1456,9 @@ def _remote_method(*args, **kwargs): # Non-exceptional discovery using get-kind kind, payload = self._wp.get_kind(name) if kind == "value": + # If worker returned a safe libstempo param marker, expose a proxy + if isinstance(payload, dict) and payload.get("__libstempo_param__"): + return _ParamProxy(self, name) return payload if kind == "callable": return _remote_method diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index 7746cb8..9c345ac 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -179,6 +179,58 @@ def test_stoas_edit_and_fit_matches_native(self): self.assertEqual(res_s.shape, res_n.shape) assert_allclose(res_s, res_n, rtol=0, atol=0) + def test_param_attribute_proxy_no_pickling_and_roundtrip(self): + """Access psr.RAJ attribute (check pickling error) and roundtrip fields.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + + # This attribute 
access used to force pickling of a non-picklable param object. + p = psr.RAJ + + # Read primitives + v0 = float(p.val) + e0 = float(p.err) + f0 = bool(p.fit) + + # Sanity on types + self.assertIsInstance(v0, float) + self.assertIsInstance(e0, float) + self.assertIsInstance(f0, bool) + + # Roundtrip same values (ensures proxy->worker set path works) + p.val = v0 + self.assertAlmostEqual(float(psr.RAJ.val), v0, places=12) + + p.err = e0 + self.assertAlmostEqual(float(psr.RAJ.err), e0, places=12) + + # Toggle fit and revert + p.fit = not f0 + self.assertEqual(bool(psr.RAJ.fit), (not f0)) + p.fit = f0 + self.assertEqual(bool(psr.RAJ.fit), f0) + + # Proxy should be printable + _ = repr(p) + _ = str(p) + + def test_param_attribute_vs_mapping_consistency(self): + """Ensure attribute-style psr.RAJ and mapping psr['RAJ'] remain consistent.""" + psr = tempopulsar(parfile=self.parfile, timfile=self.timfile, dofit=False) + + # Values and errors agree between attribute and mapping APIs + self.assertAlmostEqual(float(psr.RAJ.val), float(psr["RAJ"].val), places=12) + self.assertAlmostEqual(float(psr.RAJ.err), float(psr["RAJ"].err), places=12) + + # Setting via attribute reflects in mapping + new_val = float(psr.RAJ.val) + psr.RAJ.val = new_val + self.assertAlmostEqual(float(psr["RAJ"].val), new_val, places=12) + + # Setting via mapping reflects in attribute + new_err = float(psr["RAJ"].err) + psr["RAJ"].err = new_err + self.assertAlmostEqual(float(psr.RAJ.err), new_err, places=12) + class TestStateManagement(unittest.TestCase): """Tests for state management and crash recovery.""" From 268f0d1c3221ff06c7f3d6db068772d1b0645ad4 Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Mon, 13 Oct 2025 20:41:08 +0000 Subject: [PATCH 15/16] fix: more robust b64 decoding by first checking for encode attribute --- libstempo/sandbox.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 7786f55..9cfa7f6 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -167,7 +167,11 @@ def _b64_dumps_py(obj: Any) -> str: def _b64_loads_py(s: str) -> Any: """Deserialize base64-encoded string to Python object using cloudpickle.""" - return _cp.loads(base64.b64decode(s.encode("ascii"))) + if hasattr(s, 'encode'): + s_str = s + else: + s_str = str(s) + return _cp.loads(base64.b64decode(s_str.encode("ascii"))) def _format_exc_tuple() -> Tuple[str, str, str]: From 2841a4d9a630c2dee6ad3253097168156e9a4eea Mon Sep 17 00:00:00 2001 From: Rutger van Haasteren Date: Tue, 2 Dec 2025 15:07:17 +0000 Subject: [PATCH 16/16] fix: RPC-JSON communication was polluted by warning when astropy is not installed. This fixes that bug. --- libstempo/sandbox.py | 66 ++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/libstempo/sandbox.py b/libstempo/sandbox.py index 9cfa7f6..18e621b 100644 --- a/libstempo/sandbox.py +++ b/libstempo/sandbox.py @@ -54,10 +54,10 @@ - Error details and recovery attempts Robustness: - The sandbox suppresses libstempo debug output during construction - to prevent interference with the JSON-RPC protocol. This ensures reliable - communication even when libstempo prints diagnostic messages. The suppression - works at the OS file descriptor level to catch output from C libraries. + The sandbox redirects stdout to stderr in worker processes before importing libstempo, + preventing interference with the JSON-RPC protocol. 
This ensures reliable + communication even when libstempo or other libraries print diagnostic messages. + The redirection works at the OS file descriptor level to catch output from C libraries. """ from __future__ import annotations @@ -211,14 +211,32 @@ def _worker_stdio_main() -> None: Methods: ctor, get, set, call, del, rss, bye Each request's 'params_b64' is a pickled dict of parameters. Each response uses 'result_b64' for Python results, or 'error'. + + To prevent interference with the JSON-RPC protocol, stdout is redirected to stderr + before importing libstempo. This ensures clean communication even when libstempo + prints diagnostic messages. The redirection works at the OS file descriptor level. """ - # Permanently redirect C-level stdout (FD 1) to stderr (FD 2), - # while keeping JSON-RPC on a dedicated duplicate of the original stdout pipe. import os as _os_for_fds - _proto_fd = _os_for_fds.dup(1) # save original stdout FD for protocol - _os_for_fds.dup2(2, 1) # route any C/printf stdout to stderr - sys.stdout = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) + # Check if redirection was already set up by the subprocess command + _proto_out = None # type: ignore + _env_fd = os.environ.get("TEMPO2_SANDBOX_PROTO_FD") + if _env_fd is not None: + try: + _proto_fd = int(_env_fd) + _proto_out = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) + except Exception: + _proto_out = None + finally: + # Remove the hint to avoid leaking to children + with contextlib.suppress(Exception): + os.environ.pop("TEMPO2_SANDBOX_PROTO_FD", None) + if _proto_out is None: + # Fallback: perform redirection here if not done in command + _proto_fd = _os_for_fds.dup(1) # save original stdout FD for protocol + _os_for_fds.dup2(2, 1) # route any C/printf stdout to stderr + sys.stdout = sys.stderr # route Python-level prints to stderr + _proto_out = _os_for_fds.fdopen(_proto_fd, "w", buffering=1) # Step 1: hello handshake hello = { @@ -249,8 +267,8 @@ def _worker_stdio_main() -> None: except Exception: pass finally: - sys.stdout.write(json.dumps(hello) + "\n") - sys.stdout.flush() + _proto_out.write(json.dumps(hello) + "\n") + _proto_out.flush() # If libstempo failed to import at hello, try once more here to return clean errors try: @@ -264,8 +282,8 @@ def _worker_stdio_main() -> None: def _write_response(resp: Dict[str, Any]) -> None: """Write JSON response to stdout and flush.""" - sys.stdout.write(json.dumps(resp) + "\n") - sys.stdout.flush() + _proto_out.write(json.dumps(resp) + "\n") + _proto_out.flush() # JSON-RPC loop for line in sys.stdin: @@ -1037,7 +1055,7 @@ def _resolve_worker_cmd(env_name: Optional[str]) -> Tuple[List[str], bool]: Returns (cmd, require_x86_64) """ - # Base invocation that runs this file in worker mode: + # Base invocation that runs this file in worker mode. 
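+    # Design note: the -c bootstrap assembled below stashes a duplicate of the
+    # real stdout FD in TEMPO2_SANDBOX_PROTO_FD and then points FD 1 at stderr,
+    # so import-time chatter on stdout can never corrupt the JSON-RPC stream.
+    # A minimal sketch of that FD dance, assuming POSIX dup/dup2 semantics and
+    # a placeholder `frame`:
+    #
+    #     proto_fd = os.dup(1)   # private handle to the original stdout
+    #     os.dup2(2, 1)          # FD 1 (stdout) now routes to stderr
+    #     proto = os.fdopen(proto_fd, "w", buffering=1)
+    #     proto.write(json.dumps(frame) + "\n")  # protocol frames stay clean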
# Find the src directory dynamically current_file = Path(__file__).resolve() src_dir = current_file.parent.parent # Go up from libstempo/sandbox.py to src/ @@ -1048,7 +1066,15 @@ def python_to_worker_cmd(python_exe: str) -> List[str]: return [ python_exe, "-c", - f"import sys; sys.path.insert(0, '{src_path}'); import libstempo.sandbox as m; m._worker_stdio_main()", + ( + f"import os, sys; " + f"os.environ['TEMPO2_SANDBOX_PROTO_FD'] = str(os.dup(1)); " + f"os.dup2(2, 1); " + f"sys.stdout = sys.stderr; " + f"sys.path.insert(0, '{src_path}'); " + f"import libstempo.sandbox as m; " + f"m._worker_stdio_main()" + ), ] arch_prefix_env = os.environ.get("TEMPO2_SANDBOX_WORKER_ARCH_PREFIX", "").strip() @@ -1069,15 +1095,7 @@ def python_to_worker_cmd(python_exe: str) -> List[str]: # conda/mamba/micromamba if etype.startswith("conda:"): tool = etype.split(":", 1)[1] - cmd = [ - tool, - "run", - "-n", - env_name, - "python", - "-c", - f"import sys; sys.path.insert(0, '{src_path}'); import libstempo.sandbox as m; m._worker_stdio_main()", - ] + cmd = [tool, "run", "-n", env_name] + python_to_worker_cmd("python") # Choosing to require x86_64 only if user *explicitly* asks via arch prefix or env_name == "arch" require_x86_64 = "arch" in env_name.lower() if arch_prefix_env: