From 82a3ca63c9bb08c6d633d56fc15d45e73781db2d Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 12:32:11 +1300 Subject: [PATCH 01/12] feat: minimal type checking changes --- .github/workflows/types.yml | 32 ++++++ pyproject.toml | 2 +- qcore/archive_structure.py | 52 ++++++++-- qcore/cli.py | 50 ++++++---- qcore/constants.py | 30 +++--- qcore/coordinates.py | 59 +++++------ qcore/formats.py | 89 ++++++++++------- qcore/geo.py | 13 ++- qcore/grid.py | 14 ++- qcore/nhm.py | 12 ++- qcore/point_in_polygon.py | 23 +++-- qcore/shared.py | 30 +++--- qcore/simulation_structure.py | 49 +++------ qcore/siteamp_models.py | 22 +++-- qcore/src_site_dist.py | 90 +++++++---------- qcore/timeseries.py | 93 +++-------------- qcore/typing.py | 8 ++ qcore/uncertainties/distributions.py | 143 +++++++++++++++++++++++---- qcore/utils.py | 98 +----------------- qcore/xyts.py | 71 ++++++++++--- tests/test_coordinates.py | 31 ++++-- tests/test_geo.py | 77 +++++++++++---- tests/test_xyts.py | 1 - 23 files changed, 610 insertions(+), 479 deletions(-) create mode 100644 .github/workflows/types.yml create mode 100644 qcore/typing.py diff --git a/.github/workflows/types.yml b/.github/workflows/types.yml new file mode 100644 index 00000000..0f7a1c85 --- /dev/null +++ b/.github/workflows/types.yml @@ -0,0 +1,32 @@ +name: Type Check + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + typecheck: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Install project with types + run: uv sync --all-extras --dev + + - name: Run type checking with ty + run: uv run ty check diff --git a/pyproject.toml b/pyproject.toml index f64629d7..d10a92b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ dependencies = [ "pyproj", "pooch", "scipy", - "matplotlib", "numpy", "numba>=0.57.0; python_version == '3.11'", "numba>=0.59.0; python_version == '3.12'", @@ -25,6 +24,7 @@ dependencies = [ "docstring_parser", "xarray", "pyyaml", + "typing-extensions>=4.9.0", # Last change to the @deprecated module ] [project.optional-dependencies] diff --git a/qcore/archive_structure.py b/qcore/archive_structure.py index 8f4abb90..14cab287 100644 --- a/qcore/archive_structure.py +++ b/qcore/archive_structure.py @@ -1,27 +1,61 @@ """ Gives access to the folder structure of archived cybershake directories """ + from pathlib import Path from .simulation_structure import get_fault_from_realisation -def get_fault_source_dir(fault_dir: Path): - """Gets the Source directory for the given fault directory""" +def get_fault_source_dir(fault_dir: Path) -> Path: + """ + Get the Source directory for a given fault directory. + + Parameters + ---------- + fault_dir : Path + Path to the fault directory. + + Returns + ------- + Path + Path to the Source directory within the given fault directory. + """ return fault_dir / "Source" -def get_fault_im_dir(fault_dir: Path): - """Gets the IM directory for the given fault directory""" +def get_fault_im_dir(fault_dir: Path) -> Path: + """ + Get the IM directory for a given fault directory. + + Parameters + ---------- + fault_dir : Path + Path to the fault directory. + + Returns + ------- + Path + Path to the IM directory within the given fault directory. + """ return fault_dir / "IM" -def get_fault_bb_dir(fault_dir: Path): - """Gets the BB directory for the given fault directory""" - return fault_dir / "BB" +def get_IM_csv_from_root(archive_root: Path, realisation: str) -> Path: # noqa: N802 + """ + Get the full path to the IM CSV file given the archive root and realisation name. + Parameters + ---------- + archive_root : Path + Path to the root directory of the Cybershake archive. + realisation : str + Name of the realisation to locate. -def get_IM_csv_from_root(archive_root: Path, realisation: str): - """Gets the full path to the im_csv file given the archive root dir and the realistion name""" + Returns + ------- + Path + Full path to the IM CSV file for the specified realisation. + """ fault_name = get_fault_from_realisation(realisation) return get_fault_im_dir(archive_root / fault_name) / f"{realisation}.csv" diff --git a/qcore/cli.py b/qcore/cli.py index 63aeaf65..ec968a33 100644 --- a/qcore/cli.py +++ b/qcore/cli.py @@ -3,33 +3,48 @@ import inspect from collections.abc import Callable from functools import wraps -from typing import Annotated, Any, get_args, get_origin +from typing import ( + Annotated, + Any, + ParamSpec, + TypeVar, + get_args, + get_origin, +) import docstring_parser import typer from docstring_parser.common import DocstringStyle +from typer.models import ArgumentInfo, OptionInfo +# P captures the parameters (args and kwargs) of the decorated function. +P = ParamSpec("P") -# Originally written by @Genfood: https://github.com/fastapi/typer/issues/336#issuecomment-2434726193 -# Updated and modified for Python 3.13. -def from_docstring(app: typer.Typer, **kwargs: dict) -> Callable: +# R captures the return type of the decorated function. +R = TypeVar("R") + + +def from_docstring( + app: typer.Typer, + **kwargs: Any, +) -> Callable[[Callable[P, R]], Callable[P, R]]: """Apply help texts from the function's docstring to Typer arguments/options and command. Parameters ---------- app : typer.Typer The Typer application to which the command will be registered. - **kwargs : dict + **kwargs : Any Additional keyword arguments to be passed to the Typer command. Returns ------- - Callable - The decorated function with help texts applied, without overwriting - existing settings. + Callable[[Callable[P, R]], Callable[P, R]] + A decorator function that takes a command (Callable[P, R]) and + returns a wrapper (Callable[P, R]) preserving its signature P and return R. """ - def decorator(command: Callable) -> Callable: + def decorator(command: Callable[P, R]) -> Callable[P, R]: # numpydoc ignore=GL08 if command.__doc__ is None: return command @@ -66,9 +81,7 @@ def decorator(command: Callable) -> Callable: param_type, *metadata = get_args(param_type) new_metadata = [] for m in metadata: - if isinstance( - m, typer.models.ArgumentInfo | typer.models.OptionInfo - ): + if isinstance(m, ArgumentInfo | OptionInfo): if not m.help: m.help = help_text new_metadata.append(m) @@ -77,9 +90,7 @@ def decorator(command: Callable) -> Callable: ) # If it's an Option or Argument directly - elif isinstance( - param.default, (typer.models.ArgumentInfo, typer.models.OptionInfo) - ): + elif isinstance(param.default, ArgumentInfo | OptionInfo): if not param.default.help: param.default.help = help_text new_param = param @@ -103,15 +114,18 @@ def decorator(command: Callable) -> Callable: # Create a new signature with updated parameters new_sig = sig.replace(parameters=new_parameters) - # Apply the new signature to the wrapper function - # Register the command with the app + # Since the signature (P, R) is applied to the decorator result, + # the wrapper's type definition must match what command returns (R). @app.command(help=command_help.strip(), **kwargs) @wraps(command) - def wrapper(*args: Any, **kwargs: Any) -> Any: # numpydoc ignore=GL08 + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # numpydoc ignore=GL08 return command(*args, **kwargs) + # NOTE: Typer requires the dynamic signature update for runtime reflection, + # but the type checker uses the P, R generics. wrapper.__signature__ = new_sig + return wrapper return decorator diff --git a/qcore/constants.py b/qcore/constants.py index 7516a028..77b0da01 100644 --- a/qcore/constants.py +++ b/qcore/constants.py @@ -1,35 +1,43 @@ """DEPRECATED - Global constants and Enum helper.""" +from collections.abc import Generator from enum import Enum from typing import Any -from warnings import deprecated # type: ignore +from typing_extensions import deprecated # type: ignore -@deprecated + +@deprecated("Use built-in Enum") class ExtendedEnum(Enum): """DEPRECATED: Utility enum extension. Use built-in Enum.""" @classmethod - def has_value(cls, value): + def has_value(cls, value: Any) -> bool: return any(value == item.value for item in cls) @classmethod - def is_substring(cls, parent_string): + def is_substring(cls, parent_string: str) -> bool: """Check if an enum's string value is contained in the given string""" - return any(item.value in parent_string for item in cls) + return any( + not isinstance(item.value, str) or item.value in parent_string + for item in cls + ) @classmethod - def get_names(cls): + def get_names(cls) -> list[str]: return [item.name for item in cls] - def __str__(self): + def __str__(self) -> str: return self.name -@deprecated -class ExtendedStrEnum(ExtendedEnum): +@deprecated("Use built-in StrEnum") +class ExtendedStrEnum(ExtendedEnum): # type: ignore """DEPRECATED: Utility Enum extension for string mappings. Use built-in StrEnum.""" + _value_: Any + str_value: str + def __new__(cls, value: Any, str_value: str): # noqa: D102 # numpydoc ignore=GL08 obj = object.__new__(cls) obj._value_ = value @@ -37,7 +45,7 @@ def __new__(cls, value: Any, str_value: str): # noqa: D102 # numpydoc ignore=GL return obj @classmethod - def has_str_value(cls, str_value): + def has_str_value(cls, str_value: str) -> bool: return any(str_value == item.str_value for item in cls) @classmethod @@ -50,7 +58,7 @@ def from_str(cls, str_value): return item @classmethod - def iterate_str_values(cls, ignore_none=True): + def iterate_str_values(cls, ignore_none: bool = True) -> Generator[Any, None, None]: """Iterates over the string values of the enum, ignores entries without a string value by default """ diff --git a/qcore/coordinates.py b/qcore/coordinates.py index 2f95c668..0b85f465 100644 --- a/qcore/coordinates.py +++ b/qcore/coordinates.py @@ -17,8 +17,6 @@ [0]: https://www.linz.govt.nz/guidance/geodetic-system/coordinate-systems-used-new-zealand/projections/new-zealand-transverse-mercator-2000-nztm2000 """ -from typing import Union - import numpy as np import numpy.typing as npt import pyproj @@ -114,8 +112,8 @@ def nztm_to_wgs_depth(nztm_coordinates: np.ndarray) -> np.ndarray: def distance_between_wgs_depth_coordinates( - point_a: np.ndarray, point_b: np.ndarray -) -> Union[float, np.ndarray]: + point_a: npt.ArrayLike, point_b: npt.ArrayLike +) -> npt.ArrayLike: """Return the distance between two points in lat, lon, depth format. Valid only for points that can be converted into NZTM format. @@ -136,6 +134,8 @@ def distance_between_wgs_depth_coordinates( The distance (in metres) between point_a and point_b. Will return an array of floats if input contains multiple points """ + point_a = np.asarray(point_a) + point_b = np.asarray(point_b) if len(point_a.shape) > 1: return np.linalg.norm( wgs_depth_to_nztm(point_a) - wgs_depth_to_nztm(point_b), axis=1 @@ -206,16 +206,17 @@ def great_circle_bearing_to_nztm_bearing( The equivalent bearing such that: `geo.ll_shift`(*`origin`, `distance`, `ll_bearing`) ≅ nztm_heading. """ - great_circle_heading = np.array( - geo.ll_shift(*origin, distance, great_circle_bearing) - ) - - return geo.oriented_bearing_wrt_normal( - np.array([1, 0, 0]), - np.append( - wgs_depth_to_nztm(great_circle_heading) - wgs_depth_to_nztm(origin), 0 - ), - np.array([0, 0, 1]), + x, y = origin + great_circle_heading = np.array(geo.ll_shift(x, y, distance, great_circle_bearing)) + + return float( + geo.oriented_bearing_wrt_normal( + np.array([1, 0, 0]), + np.append( + wgs_depth_to_nztm(great_circle_heading) - wgs_depth_to_nztm(origin), 0 + ), + np.array([0, 0, 1]), + ) ) @@ -252,43 +253,43 @@ class SphericalProjection: Rotation angle in the projected plane in degrees. """ - def __init__(self, mlon: float, mlat: float, mrot: float, radius: float = R_EARTH): # noqa: D107 + def __init__(self, mlon: float, mlat: float, mrot: float, radius: float = R_EARTH): # noqa: D107 # numpydoc ignore=GL08 self.mlon = mlon self.mlat = mlat self.mrot = mrot self.radius = radius arg = np.radians(mrot) - cosA = np.cos(arg) - sinA = np.sin(arg) + cos_a = np.cos(arg) + sin_a = np.sin(arg) arg = np.radians(90.0 - mlat) - cosT = np.cos(arg) - sinT = np.sin(arg) + cos_t = np.cos(arg) + sin_t = np.sin(arg) arg = np.radians(mlon) - cosP = np.cos(arg) - sinP = np.sin(arg) + cos_p = np.cos(arg) + sin_p = np.sin(arg) self.amat = np.array( [ [ - cosA * cosT * cosP + sinA * sinP, - sinA * cosT * cosP - cosA * sinP, - sinT * cosP, + cos_a * cos_t * cos_p + sin_a * sin_p, + sin_a * cos_t * cos_p - cos_a * sin_p, + sin_t * cos_p, ], [ - cosA * cosT * sinP - sinA * cosP, - sinA * cosT * sinP + cosA * cosP, - sinT * sinP, + cos_a * cos_t * sin_p - sin_a * cos_p, + sin_a * cos_t * sin_p + cos_a * cos_p, + sin_t * sin_p, ], - [-cosA * sinT, -sinA * sinT, cosT], + [-cos_a * sin_t, -sin_a * sin_t, cos_t], ], dtype=np.float64, ) @property - def geod(self) -> pyproj.Geod: + def geod(self) -> pyproj.Geod: # numpydoc ignore=RT01 """pyproj.Geod: A pyproj representation of the EMOD3D earth as a Geod.""" return pyproj.Geod(ellps="sphere", a=self.radius, b=self.radius) diff --git a/qcore/formats.py b/qcore/formats.py index 1fb1f9ca..1b139032 100644 --- a/qcore/formats.py +++ b/qcore/formats.py @@ -3,15 +3,17 @@ """ import argparse - -# For some reason, ty can't find the deprecated member of the warnings module -from warnings import deprecated # type: ignore +from pathlib import Path +from typing import overload import pandas as pd +from typing_extensions import deprecated -@deprecated -def load_im_file_pd(imcsv, all_ims=False, comp=None): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_im_file_pd( + imcsv: Path | str, all_ims: bool = False, comp: str | None = None +) -> pd.DataFrame | pd.Series: """ Loads an IM file using pandas and returns a dataframe :param imcsv: FFP to im_csv @@ -21,6 +23,7 @@ def load_im_file_pd(imcsv, all_ims=False, comp=None): :param comp: component to return. Default is to return all :return: """ + df = pd.read_csv(imcsv, index_col=[0, 1]) if not all_ims: @@ -32,8 +35,20 @@ def load_im_file_pd(imcsv, all_ims=False, comp=None): return df -@deprecated -def station_file_argparser(parser=None): +@overload +def station_file_argparser( + parser: argparse.ArgumentParser, +) -> None: ... # numpydoc ignore=GL08 + + +@overload +def station_file_argparser() -> argparse.ArgumentParser: ... # numpydoc ignore=GL08 + + +@deprecated("Will be removed after Cybershake investigation concludes.") +def station_file_argparser( + parser: argparse.ArgumentParser | None = None, +) -> argparse.ArgumentParser | None: """ Return a parser object with formatting information of a generic station file. To facilitate the use of load_generic_station_file() @@ -108,17 +123,17 @@ def get_args(): return parser -@deprecated +@deprecated("Will be removed after Cybershake investigation concludes.") def load_generic_station_file( stat_file: str, stat_name_col: int = 2, lon_col: int = 0, lat_col: int = 1, - other_cols=[], - other_names=[], - sep=r"\s+", - skiprows=0, -): + other_cols: list[int] | None = None, + other_names: list[str] | None = None, + sep: str = r"\s+", + skiprows: int = 0, +) -> pd.DataFrame: """ Reads the station file of any format into a pandas dataframe @@ -140,30 +155,31 @@ def load_generic_station_file( pd.DataFrame station as index and columns lon, lat and other columns """ - cols = {"stat_name": stat_name_col} + cols: dict[str, int] = {"stat_name": stat_name_col} if lon_col is not None: cols["lon"] = lon_col if lat_col is not None: cols["lat"] = lat_col - for i, col_idx in enumerate(other_cols): - cols[other_names[i]] = col_idx - + if other_cols and other_names: + for col_idx, col_name in zip(other_cols, other_names): + cols[col_name] = col_idx return pd.read_csv( stat_file, - usecols=cols.values(), # we will be loading columns of these indices (order doesn't matter) + # we will be loading columns of these indices (order doesn't matter) + usecols=list(cols.values()), names=sorted( cols, key=cols.get - ), # eg. cols={stat_name:2, lon:0, lat:1} means names = ["lon","lat","stat_name"] + ), # eg. cols={stat_name:2, lon:0, lat:1} means names = ["lon","lat","stat_name"] # type: ignore[no-matching-overload] index_col=stat_name_col, sep=sep, header=None, skiprows=skiprows, - ) + ) # type: ignore -@deprecated -def load_station_file(station_file: str): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_station_file(station_file: str) -> pd.DataFrame: """Reads the station file into a pandas dataframe Parameters @@ -183,11 +199,11 @@ def load_station_file(station_file: str): names=["lon", "lat"], engine="c", delim_whitespace=True, - ) + ) # type: ignore[no-matching-overload] -@deprecated -def load_vs30_file(vs30_file: str): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_vs30_file(vs30_file: str) -> pd.DataFrame: """Reads the vs30 file into a pandas dataframe :param vs30_file: Path to the vs30 file @@ -197,8 +213,8 @@ def load_vs30_file(vs30_file: str): return pd.read_csv(vs30_file, sep=r"\s+", index_col=0, header=None, names=["vs30"]) -@deprecated -def load_z_file(z_file: str): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_z_file(z_file: str) -> pd.DataFrame: """Reads the z file into a pandas dataframe :param z_file: Path to the z file @@ -208,8 +224,8 @@ def load_z_file(z_file: str): return pd.read_csv(z_file, names=["z1p0", "z2p5", "sigma"], index_col=0, skiprows=1) -@deprecated -def load_station_ll_vs30(station_file: str, vs30_file: str): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_station_ll_vs30(station_file: str, vs30_file: str) -> pd.DataFrame: """Reads both station and vs30 file into a single pandas dataframe - keeps only the matching entries :param station_file: Path to the station file @@ -217,15 +233,14 @@ def load_station_ll_vs30(station_file: str, vs30_file: str): :return: pd.DataFrame station as index and columns lon, lat, vs30 """ - - vs30_df = load_vs30_file(vs30_file) - station_df = load_station_file(station_file) + vs30_df = load_vs30_file(vs30_file) # type: ignore + station_df = load_station_file(station_file) # type: ignore return vs30_df.merge(station_df, left_index=True, right_index=True) -@deprecated -def load_rrup_file(rrup_file: str): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_rrup_file(rrup_file: str) -> pd.DataFrame: """Reads the rrup file into a pandas dataframe Parameters @@ -241,14 +256,14 @@ def load_rrup_file(rrup_file: str): return pd.read_csv(rrup_file, header=0, index_col=0, engine="c") -@deprecated -def load_fault_selection_file(fault_selection_file): +@deprecated("Will be removed after Cybershake investigation concludes.") +def load_fault_selection_file(fault_selection_file: str | Path) -> dict[str, int]: """ Loads a fault selection file, returning a dictionary of fault:count pairs :param fault_selection_file: The relative or absolute path to the fault selection file :return: A dictionary of fault:count pairs for all faults found in the file """ - faults = {} + faults: dict[str, int] = {} with open(fault_selection_file) as fault_file: for lineno, line in enumerate(fault_file.readlines()): if len(line) == 0 or len(line.lstrip()) == 0 or line.lstrip()[0] == "#": diff --git a/qcore/geo.py b/qcore/geo.py index 51496b9e..8721feff 100644 --- a/qcore/geo.py +++ b/qcore/geo.py @@ -3,7 +3,6 @@ """ from math import acos, asin, atan, atan2, cos, degrees, pi, radians, sin, sqrt -from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -12,7 +11,7 @@ def get_distances( - locations: np.ndarray, lon: Union[float, np.ndarray], lat: Union[float, np.ndarray] + locations: np.ndarray, lon: float | np.ndarray, lat: float | np.ndarray ) -> np.ndarray: """ Calculates the distance between the array of locations and @@ -53,7 +52,7 @@ def closest_location( d = get_distances(locations, lon, lat) i = np.argmin(d) - return i, d[i] + return int(i), d[i] def oriented_bearing_wrt_normal( @@ -347,7 +346,7 @@ def avg_wbearing(angles: list[list[float]]) -> float: def path_from_corners( corners: list[tuple[float, float]], - output: str = "sim.modelpath_hr", + output: str | None = "sim.modelpath_hr", min_edge_points: int = 100, close: bool = True, ): @@ -446,9 +445,9 @@ def ll_cross_along_track_dist( lat2: float, lon3: float, lat3: float, - a12: Optional[float] = None, - a13: Optional[float] = None, - d13: Optional[float] = None, + a12: float | None = None, + a13: float | None = None, + d13: float | None = None, ) -> tuple[float, float]: """ Returns both the distance of point 3 to the nearest point on the great circle line that passes through point 1 and diff --git a/qcore/grid.py b/qcore/grid.py index df4179d9..61b0d83f 100644 --- a/qcore/grid.py +++ b/qcore/grid.py @@ -18,6 +18,7 @@ import numpy as np from qcore import coordinates +from qcore.typing import TFloat def grid_corners( @@ -148,8 +149,8 @@ def coordinate_meshgrid( length_x = np.linalg.norm(x_upper - origin) length_y = np.linalg.norm(y_bottom - origin) - nx = nx or gridpoint_count_in_length(length_x, resolution) - ny = ny or gridpoint_count_in_length(length_y, resolution) + nx = nx or gridpoint_count_in_length(float(length_x), resolution) + ny = ny or gridpoint_count_in_length(float(length_y), resolution) # We first create a meshgrid of coordinates across a flat rectangle like the following # @@ -203,7 +204,7 @@ def coordinate_patchgrid( origin: np.ndarray, x_upper: np.ndarray, y_bottom: np.ndarray, - resolution: Optional[float] = None, + resolution: Optional[TFloat] = None, nx: Optional[int] = None, ny: Optional[int] = None, ) -> np.ndarray: @@ -261,8 +262,11 @@ def coordinate_patchgrid( len_x = np.linalg.norm(v_x) len_y = np.linalg.norm(v_y) - nx = nx or max(1, round(len_x / resolution)) - ny = ny or max(1, round(len_y / resolution)) + if not resolution and not (nx and ny): + raise ValueError("If resolution is not provided, nx and ny must be.") + + nx = nx or max(1, round(float(len_x / resolution))) # type: ignore + ny = ny or max(1, round(float(len_y / resolution))) # type: ignore alpha, beta = np.meshgrid( # The 1 / (2 * nx) term is to ensure that the patches are centred on the grid points. diff --git a/qcore/nhm.py b/qcore/nhm.py index fb548db6..1359d9fc 100644 --- a/qcore/nhm.py +++ b/qcore/nhm.py @@ -86,7 +86,9 @@ class NHMFault: # TODO: add x y z fault plane data as in SRF info # TODO: add leonard mw function - def sample_2012(self, mw_area_scaling: bool = True, mw_perturbation: bool = True): + def sample_2012( + self, mw_area_scaling: bool = True, mw_perturbation: bool = True + ) -> "NHMFault": """ Permutates the current NHM fault as per the OpenSHA implementation. This uses the same Mw scaling relations as Stirling 2012 @@ -149,7 +151,7 @@ def sample_2012(self, mw_area_scaling: bool = True, mw_perturbation: bool = True trace=self.trace, ) - def write(self, out_fp: TextIO, header: bool = False): + def write(self, out_fp: TextIO, header: bool = False) -> None: """ Writes a section of the NHM file @@ -179,7 +181,7 @@ def write(self, out_fp: TextIO, header: bool = False): def load_nhm( nhm_path: str | None = None, skiprows: int = len(NHM_HEADER.splitlines()) + 1 -): +) -> dict[str, NHMFault]: """Reads the nhm_path and returns a dictionary of NHMFault by fault name. Parameters @@ -299,7 +301,9 @@ def load_nhm_df(nhm_ffp: str, erf_name: str | None = None): return pd.DataFrame.from_dict(rupture_dict, orient="index").sort_index() -def get_fault_header_points(fault: NHMFault): +def get_fault_header_points( + fault: NHMFault, +) -> tuple[list[dict[str, int | float]], np.ndarray]: """ Calculates and produces fault information such as the entire trace and fault header info per plane diff --git a/qcore/point_in_polygon.py b/qcore/point_in_polygon.py index 2d0a0ed1..e90e8add 100644 --- a/qcore/point_in_polygon.py +++ b/qcore/point_in_polygon.py @@ -1,10 +1,17 @@ -from numba import jit, njit +from typing import Literal + import numba import numpy as np +import numpy.typing as npt +from numba import jit, njit + +from qcore.typing import TNFloat @jit(nopython=True) -def is_inside_postgis(polygon: np.ndarray, point: np.ndarray): +def is_inside_postgis( + polygon: npt.NDArray[TNFloat], point: npt.NDArray[TNFloat] +) -> Literal[0, 1, 2]: # pragma: no cover """ Function that checks if a point is inside a polygon Based on solutions found here @@ -38,7 +45,7 @@ def is_inside_postgis(polygon: np.ndarray, point: np.ndarray): dy2 = point[1] - polygon[jj][1] # Check if the point is on the polygon - F = (dx - dx2) * dy - dx * (dy - dy2) + F = (dx - dx2) * dy - dx * (dy - dy2) # noqa: N806 if 0.0 == F and dx * dx2 <= 0 and dy * dy2 <= 0: return 2 @@ -50,11 +57,13 @@ def is_inside_postgis(polygon: np.ndarray, point: np.ndarray): jj += 1 - return intersections != 0 + return 1 if intersections != 0 else 0 @njit(parallel=True) -def is_inside_postgis_parallel(points: np.ndarray, polygon: np.ndarray): +def is_inside_postgis_parallel( + points: npt.NDArray[TNFloat], polygon: npt.NDArray[TNFloat] +) -> npt.NDArray[np.bool_]: # pragma: no cover """ Function that checks if a set of points is inside a polygon in parallel @@ -71,7 +80,7 @@ def is_inside_postgis_parallel(points: np.ndarray, polygon: np.ndarray): List of boolean values that indicate if the point is inside the polygon """ ln = len(points) - D = np.empty(ln, dtype=numba.boolean) - for i in numba.prange(ln): + D = np.empty(ln, dtype=np.bool_) # noqa: N806 + for i in numba.prange(ln): # type: ignore D[i] = is_inside_postgis(polygon, points[i]) return D diff --git a/qcore/shared.py b/qcore/shared.py index a83aa287..4cc97180 100644 --- a/qcore/shared.py +++ b/qcore/shared.py @@ -5,11 +5,12 @@ import re import subprocess import sys -from io import IOBase +from io import FileIO from pathlib import Path from typing import AnyStr, Optional, Union import pandas as pd +from typing_extensions import deprecated def get_stations( @@ -99,15 +100,15 @@ def get_corners( return corners, cnr_str +@deprecated("use subprocess.run or subprocess.check_call") def non_blocking_exe( cmd: Union[str, list[str]], debug: bool = True, - stdout: Union[bool, IOBase] = True, - stderr: Union[bool, IOBase] = True, + stdout: Union[bool, FileIO] = True, + stderr: Union[bool, FileIO] = True, **kwargs, -) -> subprocess.Popen: - """Run a command without blocking the calling thread. - +) -> subprocess.Popen: # pragma: no cover + """ *DO NOT USE THIS FUNCTION* Instead, call subprocess.run or subprocess.check_call to execute processes. @@ -148,28 +149,31 @@ def non_blocking_exe( if debug: virtual_cmd = " ".join(cmd) - if isinstance(stdout, IOBase): + if isinstance(stdout, FileIO): virtual_cmd += f" 1>{stdout.name}" - if isinstance(stderr, IOBase): + if isinstance(stderr, FileIO): virtual_cmd += f" 2>{stderr.name}" print(virtual_cmd, file=sys.stderr) # special cases for stderr and stdout + stdout_pipe = stdout + stderr_pipe = stderr if stdout is True: - stdout = subprocess.PIPE + stdout_pipe = subprocess.PIPE if stderr is True: - stderr = subprocess.PIPE + stderr_pipe = subprocess.PIPE - p = subprocess.Popen(cmd, stdout=stdout, stderr=stderr, **kwargs) + p = subprocess.Popen(cmd, stdout=stdout_pipe, stderr=stderr_pipe, **kwargs) return p +@deprecated("use subprocess.run or subprocess.check_call") def exe( cmd: Union[str, list[str]], debug: bool = True, stdin: Optional[AnyStr] = None, **kwargs, -) -> Union[tuple[str, str], tuple[bytes, bytes]]: +) -> Union[tuple[str, str], tuple[bytes, bytes]]: # pragma: no cover """ Runs a command in the shell using the provided parameters. @@ -199,7 +203,7 @@ def exe( conversion fails. """ - exe_process = non_blocking_exe(cmd, debug=debug, **kwargs) + exe_process = non_blocking_exe(cmd, debug=debug, **kwargs) # type: ignore out, err = exe_process.communicate(stdin) _ = exe_process.wait() diff --git a/qcore/simulation_structure.py b/qcore/simulation_structure.py index 772f7210..881dfc38 100644 --- a/qcore/simulation_structure.py +++ b/qcore/simulation_structure.py @@ -5,27 +5,21 @@ import os -def get_fault_from_realisation(realisation): +def get_fault_from_realisation(realisation: str) -> str: realisation = os.path.basename(realisation) # if realisation is a fullpath return realisation.rsplit("_REL", 1)[0] -def get_realisation_name(fault_name, rel_no): +def get_realisation_name(fault_name: str, rel_no: int) -> str: return f"{fault_name}_REL{rel_no:0>2}" -# SRF -def get_srf_location(realisation): - fault = get_fault_from_realisation(realisation) - return os.path.join(fault, "Srf", realisation + ".srf") - - -def get_srf_info_location(realisation): +def get_srf_info_location(realisation: str) -> str: fault = get_fault_from_realisation(realisation) return os.path.join(fault, "Srf", realisation + ".info") -def get_srf_dir(cybershake_root, realisation): +def get_srf_dir(cybershake_root: str, realisation: str) -> str: return os.path.join( cybershake_root, "Data", @@ -35,59 +29,46 @@ def get_srf_dir(cybershake_root, realisation): ) -def get_srf_path(cybershake_root, realisation): - return os.path.join( - cybershake_root, "Data", "Sources", get_srf_location(realisation) - ) - - -# Stoch -def get_stoch_location(realisation): +def get_srf_location(realisation: str) -> str: fault = get_fault_from_realisation(realisation) - return os.path.join(fault, "Stoch", realisation + ".stoch") + return os.path.join(fault, "Srf", realisation + ".srf") -def get_stoch_path(cybershake_root, realisation): +def get_srf_path(cybershake_root: str, realisation: str) -> str: return os.path.join( - cybershake_root, "Data", "Sources", get_stoch_location(realisation) + cybershake_root, "Data", "Sources", get_srf_location(realisation) ) -def get_runs_dir(cybershake_root): - """Gets the path to the Runs directory of a cybershake run""" - return os.path.join(cybershake_root, "Runs") - - -def get_fault_dir(cybershake_root, fault_name): - return os.path.join(get_runs_dir(cybershake_root), fault_name) +def get_fault_dir(cybershake_root: str, fault_name: str) -> str: + return os.path.join(cybershake_root, "Runs", fault_name) -def get_sim_dir(cybershake_root, realisation): +def get_sim_dir(cybershake_root: str, realisation: str) -> str: return os.path.join( get_fault_dir(cybershake_root, get_fault_from_realisation(realisation)), realisation, ) -def get_im_calc_dir(sim_root, realisation=None): +def get_im_calc_dir(sim_root: str, realisation: str | None = None) -> str: if realisation is None: return os.path.join(sim_root, "IM_calc") else: return get_im_calc_dir(get_sim_dir(sim_root, realisation)) -def get_IM_csv_from_root(cybershake_root, realisation): +def get_IM_csv_from_root(cybershake_root: str, realisation: str) -> str: # noqa: N802 return os.path.join( get_im_calc_dir(get_sim_dir(cybershake_root, realisation)), "{}.{}".format(realisation, "csv"), ) -def get_fault_yaml_path(sim_root, fault_name=None): +def get_fault_yaml_path(sim_root: str, fault_name: str | None = None) -> str: """ Gets the fault_params.yaml for the specified simulation. Note: For the manual workflow set fault_name to None as the fault params are stored directly in the simulation directory. """ - fault_name = "" if fault_name is None else fault_name - return os.path.join(sim_root, fault_name, "fault_params.yaml") + return os.path.join(sim_root, fault_name or "", "fault_params.yaml") diff --git a/qcore/siteamp_models.py b/qcore/siteamp_models.py index e8d7094d..177e7c00 100644 --- a/qcore/siteamp_models.py +++ b/qcore/siteamp_models.py @@ -87,7 +87,7 @@ def _fs_low( c10: np.ndarray, k1: np.ndarray, k2: np.ndarray, -) -> np.ndarray: +) -> np.ndarray: # pragma: no cover """Compute site factor based on vs30 value - low code path Parameters @@ -139,7 +139,9 @@ def _fs_mid( @njit -def _fs_high(t_idx: int, c10: np.ndarray, k1: np.ndarray, k2: np.ndarray): +def _fs_high( + t_idx: int, c10: np.ndarray, k1: np.ndarray, k2: np.ndarray +): # pragma: no cover """Compute site factor based on vs30 value - high code path Parameters @@ -167,7 +169,7 @@ def _compute_fs_value( c10: np.ndarray, k1: np.ndarray, k2: np.ndarray, -): +): # pragma: no cover """Compute site factor based on vs30 value Parameters @@ -201,7 +203,7 @@ def _cb_amp( version: int = 2014, flowcap: float = 0.0, freqs: np.ndarray = AMPLIFICATION_FREQUENCIES, -) -> np.ndarray: +) -> np.ndarray: # pragma: no cover """ Numba translation of cb_amp. @@ -375,7 +377,7 @@ def _cb_amp_multi( version: int, flowcap: float, freqs: np.ndarray, -) -> np.ndarray: +) -> np.ndarray: # pragma: no cover """Numba version of cb_amp that processes multiple parameter sets. Parameters @@ -526,7 +528,7 @@ def cb_amp_multi( # Use pga for reference dtype because it is more reliably a float, # where vref can sometimes be an int. - freqs = freqs.astype(pga.dtype) + freqs = freqs.astype(pga.dtype) # type: ignore[no-matching-overload] # Call the numba-accelerated function results = _cb_amp_multi( vref=vref, @@ -584,7 +586,9 @@ def cb2014_to_fas_amplification_factors( @njit( parallel=True, ) -def interp_2d(x: np.ndarray, xp: np.ndarray, fp: np.ndarray) -> np.ndarray: +def interp_2d( + x: np.ndarray, xp: np.ndarray, fp: np.ndarray +) -> np.ndarray: # pragma: no cover """Perform interpolation of a vector-valued function f at `x` with interpolation nodes `xp` and `fp`. This handles the case where `fp` is not 1-D. Interpolation is @@ -688,7 +692,7 @@ def amp_bandpass( fmidbot: float, fmin: float, fftfreq: np.ndarray, -) -> np.ndarray: +) -> np.ndarray: # pragma: no cover """Frequency-dependent amplification adjustment for site amplification factors. This function applies frequency-dependent amplification adjustments @@ -722,7 +726,7 @@ def amp_bandpass( logarithmically between (fmin, fmidbot]. fmin : float The minimum frequency. Amplification is set to 1 below this frequency. - ftfreq : np.ndarray + fftfreq : np.ndarray A 1D array of Fourier transform frequencies corresponding to the amplification values. diff --git a/qcore/src_site_dist.py b/qcore/src_site_dist.py index 1b2ea47d..2af181b6 100644 --- a/qcore/src_site_dist.py +++ b/qcore/src_site_dist.py @@ -3,15 +3,30 @@ History of this file: https://github.com/ucgmsim/IM_calculation/commits/afa9bf02d5e197300e3a91f87a9136b4ebcabd62/IM_calculation/source_site_dist/src_site_dist.py """ -from typing import List, Dict -import matplotlib.path as mpltPath +from typing import Literal, overload + import numpy as np from qcore import geo -VOLCANIC_FRONT_COORDS = [(175.508, -39.364), (177.199, -37.73)] -VOLCANIC_FRONT_LINE = mpltPath.Path(VOLCANIC_FRONT_COORDS) + +@overload +def calc_rrup_rjb( + srf_points: np.ndarray, + locations: np.ndarray, + n_stations_per_iter: int = 1000, + return_rrup_points: Literal[True] = True, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ... # numpydoc ignore=GL08 + + +@overload +def calc_rrup_rjb( + srf_points: np.ndarray, + locations: np.ndarray, + n_stations_per_iter: int = 1000, + return_rrup_points: Literal[False] = False, +) -> tuple[np.ndarray, np.ndarray]: ... # numpydoc ignore=GL08 def calc_rrup_rjb( @@ -19,7 +34,7 @@ def calc_rrup_rjb( locations: np.ndarray, n_stations_per_iter: int = 1000, return_rrup_points: bool = False, -): +) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, np.ndarray]: """Calculates rrup and rjb distance Parameters @@ -83,11 +98,11 @@ def calc_rrup_rjb( def calc_rx_ry( srf_points: np.ndarray, - plane_infos: List[Dict], + plane_infos: list[dict], locations: np.ndarray, hypocentre_origin: bool = False, type: int = 2, -): +): # pragma: no cover """ A wrapper script allowing external function calls to resolve to the correct location. @@ -120,9 +135,9 @@ def calc_rx_ry( raise ValueError(f"Invalid GC type. {type} not in {{1,2}}") -def calc_rx_ry_GC1( - srf_points: np.ndarray, plane_infos: List[Dict], locations: np.ndarray -): +def calc_rx_ry_GC1( # noqa: N802 + srf_points: np.ndarray, plane_infos: list[dict], locations: np.ndarray +): # pragma: no cover """ Calculates Rx and Ry distances using the cross track and along track distance calculations. Uses the plane nearest to each of the given locations if there are multiple. @@ -199,19 +214,21 @@ def calc_rx_ry_GC1( up_strike_top_point, ) + tp_lon, tp_lat = up_strike_top_point + ds_lon, ds_lat = down_strike_top_point r_x[iloc], r_y[iloc] = geo.ll_cross_along_track_dist( - *up_strike_top_point, *down_strike_top_point, lon, lat + tp_lon, tp_lat, ds_lon, ds_lat, lon, lat ) return r_x, r_y -def calc_rx_ry_GC2( +def calc_rx_ry_GC2( # noqa: N802 srf_points: np.ndarray, - plane_infos: List[Dict], + plane_infos: list[dict], locations: np.ndarray, hypocentre_origin: bool = False, -): +): # pragma: no cover """ Calculates Rx and Ry distances using the cross track and along track distance calculations. If there are multiple fault planes, the Rx, Ry values are calculated for each fault plane individually, then weighted @@ -255,12 +272,12 @@ def calc_rx_ry_GC2( return r_x[0], r_y[0] -def calc_rx_ry_GC2_multi_hypocentre( +def calc_rx_ry_GC2_multi_hypocentre( # noqa: N802 srf_points: np.ndarray, - plane_infos: List[Dict], + plane_infos: list[dict], locations: np.ndarray, origin_offsets: np.ndarray = np.asarray([0]), -): +): # pragma: no cover """ Vectorised version of the GC2 calculation along multiple hypocentre locations. Calculates Rx and Ry distances using the cross track and along track distance calculations. @@ -324,42 +341,3 @@ def calc_rx_ry_GC2_multi_hypocentre( r_y = r_y_values / weights return r_x, r_y - - -def calc_backarc(srf_points: np.ndarray, locations: np.ndarray): - """ - This is a crude approximation of stations that are on the backarc. Defined by source-site lines that cross the - Volcanic front line. - https://user-images.githubusercontent.com/25143301/111406807-ce5bb600-8737-11eb-9c78-b909efe7d9db.png - https://user-images.githubusercontent.com/25143301/111408728-93a74d00-873a-11eb-9afa-5e8371ee2504.png - - Parameters - ---------- - srf_points : np.ndarray - The fault points from the srf file (qcore, srf.py, read_srf_points), - format (lon, lat, depth). - locations : np.ndarray - The locations for which to calculate the distances, - format (lon, lat, depth). - - Returns - ------- - np.ndarray - A numpy array returning 0 if the station is on the forearc and 1 if the station is on the backarc. - """ - n_locations = locations.shape[0] - backarc = np.zeros(n_locations, dtype=np.int) - for loc_index in range(n_locations): - # Selection is every 40 SRF points (4 km) - the backarc line is ~200km long. - # In the case of point sources it will just take the first point - for srf_point in srf_points[::40]: - srf_stat_line = mpltPath.Path( - [ - (srf_point[0], srf_point[1]), - (locations[loc_index][0], locations[loc_index][1]), - ] - ) - if VOLCANIC_FRONT_LINE.intersects_path(srf_stat_line): - backarc[loc_index] = 1 - break - return backarc diff --git a/qcore/timeseries.py b/qcore/timeseries.py index cde4ac93..2b5ad400 100644 --- a/qcore/timeseries.py +++ b/qcore/timeseries.py @@ -10,7 +10,7 @@ import os from enum import StrEnum, auto from pathlib import Path -from typing import NamedTuple +from typing import Literal, NamedTuple import numpy as np import numpy.typing as npt @@ -89,8 +89,11 @@ def bwfilter( case Band.LOWPASS: cutoff_frequencies = taper_frequency * _BW_LOWPASS_SHIFT + btype: Literal["highpass"] | Literal["lowpass"] = ( + "highpass" if band == Band.HIGHPASS else "lowpass" + ) return sp.signal.sosfiltfilt( - sp.signal.butter(4, cutoff_frequencies, btype=band, output="sos", fs=1.0 / dt), + sp.signal.butter(4, cutoff_frequencies, btype=btype, output="sos", fs=1.0 / dt), waveform, padtype=None, ) @@ -131,7 +134,9 @@ def ampdeamp( the values of `amplification_factor`. """ - pyfftw.config.NUM_THREADS = cores + # PyFFTW sets the globals in the config module dynamically. + # So type-checking must be ignored here. + pyfftw.config.NUM_THREADS = cores # type: ignore nt = waveform.shape[-1] waveform_dtype = waveform.dtype @@ -179,79 +184,6 @@ def ampdeamp( return result_full[..., :nt] -def transf( - vs_soil: np.ndarray, - rho_soil: np.ndarray, - damp_soil: np.ndarray, - height_soil: np.ndarray, - vs_rock: np.ndarray, - rho_rock: np.ndarray, - damp_rock: np.ndarray, - nt: int, - dt: float, - ft_freq: np.ndarray | None = None, -) -> np.ndarray: - """Used in de-convolution of high-frequency site-response modelling. - - Can be used instead of traditional Vs30-based site-response when - the relevant input parameters are known. Made by Chris de la - Torre. It is part of the workflow described in [0]_. - - Parameters - ---------- - vs_soil : array of floats - The shear wave velocity in upper soil. - rho_soil : array of floats - The upper soil density. - damp_soil : array of floats - The upper soil damping ratio. - height_soil : array of floats - The height of the upper soil. - vs_rock : array of floats - The shear wave velocity in rock. - rho_rock : array of floats - The rock density. - damp_rock : array of floats - The rock damping ratio. - nt : float - The number of timesteps in input waveform. - dt : float - Waveform timestep. - ft_freq : array of floats, optional - Frequency space of transformed waveform. - - Returns - ------- - np.ndarray - A transfer function `H` used for waveform de-convolution. - - References - ---------- - ..[0] de la Torre, C. A., Bradley, B. A., & Lee, R. L. (2020). Modeling - nonlinear site effects in physics-based ground motion simulations - of the 2010–2011 Canterbury earthquake sequence. Earthquake - Spectra, 36(2), 856-879. - """ - if ft_freq is None: - ft_len = int(2.0 ** np.ceil(np.log(nt) / np.log(2))) - ft_freq = np.arange(0, ft_len / 2 + 1) / (ft_len * dt) - omega = 2.0 * np.pi * ft_freq - Gs = rho_soil * vs_soil**2.0 # noqa: N806 - Gr = rho_rock * vs_rock**2.0 # noqa: N806 - - kS = omega / (vs_soil * (1.0 + 1j * damp_soil)) # noqa: N806 - kR = omega / (vs_rock * (1.0 + 1j * damp_rock)) # noqa: N806 - - alpha = Gs * kS / (Gr * kR) - - H = 2.0 / ( # noqa: N806 - (1.0 + alpha) * np.exp(1j * kS * height_soil) - + (1.0 - alpha) * np.exp(-1j * kS * height_soil) - ) - H[0] = 1 - return H - - _HEAD_STAT = 48 # Header size per station _N_COMP = 9 # Number of components in LF seis files @@ -282,7 +214,12 @@ class LFSeisParser: `fileno`. """ - def __init__(self, handle: io.BufferedReader): # noqa: D107 + handle: io.BufferedIOBase + length: int + i4: str + f4: str + + def __init__(self, handle: io.BufferedIOBase): # noqa: D107 # numpydoc ignore=GL08 self.handle = handle self.length = self._extract_length() self.i4, self.f4 = self._lfseis_dtypes() @@ -622,7 +559,7 @@ def timeseries_to_text( az: float = 0.0, baz: float = 0.0, title: str = "", -): +) -> None: """ Store timeseries data into a text file. diff --git a/qcore/typing.py b/qcore/typing.py new file mode 100644 index 00000000..c827bc4a --- /dev/null +++ b/qcore/typing.py @@ -0,0 +1,8 @@ +"""Typing utilities.""" + +from typing import Any, TypeVar + +import numpy as np + +TNFloat = TypeVar("TNFloat", bound=np.floating[Any]) +TFloat = TypeVar("TFloat", bound=np.floating[Any] | float) diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index 1c3304ae..06c39b28 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -13,13 +13,40 @@ - rand_shyp: Generates random hypocentre values along the length of a fault. """ -from typing import Optional +from typing import Literal, overload import numpy as np +import numpy.typing as npt import scipy as sp -def truncated_normal(mean: float, std_dev: float, std_dev_limit: float = 2) -> float: +@overload +def truncated_normal( + mean: float, + std_dev: float, + std_dev_limit: float = 2, + size: Literal[1] = 1, + seed: int | None = None, +) -> float: ... # numpydoc ignore=GL08 + + +@overload +def truncated_normal( + mean: float, + std_dev: float, + std_dev_limit: float = 2, + size: int = 1, + seed: int | None = None, +) -> np.ndarray: ... # numpydoc ignore=GL08 + + +def truncated_normal( + mean: float, + std_dev: float, + std_dev_limit: float = 2, + size: int = 1, + seed: int | None = None, +) -> float | np.ndarray: """ Generate a random value from a truncated normal distribution. @@ -31,23 +58,53 @@ def truncated_normal(mean: float, std_dev: float, std_dev_limit: float = 2) -> f Standard deviation of the normal distribution. std_dev_limit : float, optional Number of standard deviations to limit the truncation (default is 2). + size : int, optional + The number of samples to take (default is 1). + seed : int or None, optional + Random seed for reproducibility (default is None). + Returns ------- - float + float or array of floats Random value from the truncated normal distribution. """ - return sp.stats.truncnorm( - -std_dev_limit, std_dev_limit, loc=mean, scale=std_dev - ).rvs() + x = sp.stats.truncnorm(-std_dev_limit, std_dev_limit, loc=mean, scale=std_dev).rvs( + size=size, random_state=seed + ) + if size == 1: + return float(x.item()) + else: + return x +@overload def truncated_weibull( upper_value: float, c: float = 3.353, scale_factor: float = 0.612, - seed: Optional[int] = None, -) -> float: + size: Literal[1] = 1, + seed: int | None = None, +) -> float: ... # numpydoc ignore=GL08 + + +@overload +def truncated_weibull( + upper_value: float, + c: float = 3.353, + scale_factor: float = 0.612, + size: int = 1, + seed: int | None = None, +) -> np.ndarray: ... # numpydoc ignore=GL08 + + +def truncated_weibull( + upper_value: float, + c: float = 3.353, + scale_factor: float = 0.612, + size: int = 1, + seed: int | None = None, +) -> float | np.ndarray: """ Generate a random value from a truncated Weibull distribution. @@ -59,17 +116,23 @@ def truncated_weibull( Shape parameter of the Weibull distribution (default is 3.353). scale_factor : float, optional Scale factor of the Weibull distribution (default is 0.612). + size : int, optional + The number of samples to take (default is 1). seed : int or None, optional Random seed for reproducibility (default is None). Returns ------- - float + float or array of floats Random value from the truncated Weibull distribution. """ - return upper_value * sp.stats.truncweibull_min( + x = upper_value * sp.stats.truncweibull_min( c, 0, 1 / scale_factor, scale=scale_factor - ).rvs(random_state=seed) + ).rvs(random_state=seed, size=size) + if size == 1: + return float(x.item()) + else: + return x def truncated_weibull_expected_value( @@ -92,15 +155,39 @@ def truncated_weibull_expected_value( float Expected value for the truncated Weibull distribution. """ - return ( + return float( upper_value * sp.stats.truncweibull_min(c, 0, 1 / scale_factor, scale=scale_factor).expect() ) +@overload def truncated_log_normal( - mean: float, std_dev: float, std_dev_limit: float = 2, seed: Optional[int] = None -) -> float: + mean: npt.ArrayLike, + std_dev: float, + std_dev_limit: float = 2, + size: Literal[1] = 1, + seed: int | None = None, +) -> float: ... # numpydoc ignore=GL08 + + +@overload +def truncated_log_normal( + mean: npt.ArrayLike, + std_dev: float, + std_dev_limit: float = 2, + size: int = 1, + seed: int | None = None, +) -> float: ... # numpydoc ignore=GL08 + + +def truncated_log_normal( + mean: npt.ArrayLike, + std_dev: float, + std_dev_limit: float = 2, + size: int = 1, + seed: int | None = None, +) -> float | np.ndarray: """ Generate a random value from a truncated log-normal distribution. @@ -112,6 +199,8 @@ def truncated_log_normal( Standard deviation of the log-normal distribution. std_dev_limit : float, optional Number of standard deviations to limit the truncation (default is 2). + size : int, optional + The number of samples to take (default is 1). seed : int or None, optional Random seed for reproducibility (default is None). @@ -120,7 +209,7 @@ def truncated_log_normal( float Random value from the truncated log-normal distribution. """ - return np.exp( + x = np.exp( sp.stats.truncnorm( -std_dev_limit, std_dev_limit, @@ -128,9 +217,25 @@ def truncated_log_normal( scale=std_dev, ).rvs(random_state=seed) ) + if size == 1: + return float(x.item()) + else: + return x + + +@overload +def rand_shyp( + size: Literal[1] = 1, seed: int | None = None +) -> float: ... # numpydoc ignore=GL08 + + +@overload +def rand_shyp( + size: int = 1, seed: int | None = None +) -> np.ndarray: ... # numpydoc ignore=GL08 -def rand_shyp() -> float: +def rand_shyp(size: int = 1, seed: int | None = None) -> float | np.ndarray: """ Generate a random hypocentre value along the length of a fault. @@ -139,4 +244,8 @@ def rand_shyp() -> float: float Random value from a truncated normal distribution (mean=0, std_dev=0.25). """ - return truncated_normal(0, 0.25) + x = truncated_normal(0, 0.25, size=size, seed=seed) + if size == 1: + return float(x.item()) + else: + return x diff --git a/qcore/utils.py b/qcore/utils.py index 048a5e93..4483e03c 100644 --- a/qcore/utils.py +++ b/qcore/utils.py @@ -3,15 +3,14 @@ Mostly related to file system operations and other non-specific functionality. """ -import os -import re -import shutil from pathlib import Path from typing import Any, Union import yaml +from typing_extensions import deprecated +@deprecated("use yaml.safe_load") def load_yaml(yaml_file: Union[Path, str]) -> Any: """Load YAML from a file. @@ -30,96 +29,3 @@ def load_yaml(yaml_file: Union[Path, str]) -> Any: """ with open(yaml_file, "r", encoding="utf-8") as stream: return yaml.safe_load(stream) - - -def dump_yaml(object: Any, output_name: Union[Path, str]): - """Dump an object to a YAML file. - - *DO NOT USE*. This function exists for backwards compatibility only. Just - use yaml.safe_dump instead. - - Parameters - ---------- - object : Any - The object to dump. - output_name : Union[Path, str] - The filepath to dump to. - """ - with open(output_name, "w", encoding="utf-8") as yaml_file: - yaml.safe_dump(object, yaml_file) - - -def setup_dir(directory: str, empty: bool = False): - """Ensure a directory exists and, optionally, that it is empty. - - Parameters - ---------- - directory : str - The directory to check. - empty : bool - If True, check if the directory is empty. - """ - if os.path.exists(directory) and empty: - shutil.rmtree(directory) - if not os.path.exists(directory): - # multi processing safety (not useful with empty set) - try: - os.makedirs(directory) - except OSError: - if not os.path.isdir(directory): - raise - - -def compare_versions(version1: str, version2: str, split_char: str = ".") -> int: - """Compare two version strings. - - Comparison is made on the individual parts of each version. Where the - versions are equivalent but the number of parts differs, i.e. comparing - 1.0 and 1, the longer version string is considered newer. - - Parameters - ---------- - version1 : str - The first version string to check. - version2 : str - The second version string to check. - split_char : str - The version separator. - - Returns - ------- - int - Returns 1 if version1 is newer than version 2, -1 if version2 is - newer than version1 and 0 otherwise. - - Examples - -------- - >>> compare_versions('1.0.0', '1') - 0 - >>> compare_versions('1.0.1', '1') - 1 - >>> compare_versions('1.0.1', '1.1') - -1 - """ - invalid_version_characters = f"[^0-9{re.escape(split_char)}]" - parts1 = [ - int(part) - for part in re.sub(invalid_version_characters, "", version1).split(split_char) - ] - parts2 = [ - int(part) - for part in re.sub(invalid_version_characters, "", version2).split(split_char) - ] - max_length = max(len(parts1), len(parts2)) - - if parts1[:max_length] > parts2[:max_length]: - return 1 - if parts1[:max_length] < parts2[:max_length]: - return -1 - - if len(parts1) > len(parts2): - return 1 - if len(parts2) > len(parts1): - return -1 - - return 0 diff --git a/qcore/xyts.py b/qcore/xyts.py index dbd3d17a..2a6dc3b4 100644 --- a/qcore/xyts.py +++ b/qcore/xyts.py @@ -39,13 +39,14 @@ """ import dataclasses +from enum import Enum from math import cos, radians, sin from pathlib import Path +from typing import Literal, overload import numpy as np from qcore import geo -from enum import Enum class Component(Enum): @@ -142,16 +143,29 @@ class XYTSFile: nx_sim: int dip: float comps: dict[str, float] - cosR: float - sinR: float - cosP: float - sinP: float + + cos_r: float + + sin_r: float + + cos_p: float + + sin_p: float + rot_matrix: np.ndarray + # proc-local files only local_nx: int | None = None + local_ny: int | None = None + local_nz: int | None = None + # data arrays + data: np.memmap | None = None + + ll_map: np.ndarray | None = None + # contents data: np.memmap | None = ( None # NOTE: this is distinct (but nearly identical to) a np.ndarray @@ -159,13 +173,13 @@ class XYTSFile: ll_map: np.ndarray | None = None - def __init__( + def __init__( # noqa: D107 self, xyts_path: Path | str, meta_only: bool = False, proc_local_file: bool = False, round_dt: bool = True, - ): + ): # numpydoc ignore=GL08 """Initializes the XYTSFile object. Parameters @@ -239,11 +253,11 @@ def __init__( "Z": radians(90 - self.dip), } # rotation of components so Y is true north - self.cosR = cos(self.comps["X"]) - self.sinR = sin(self.comps["X"]) + self.cos_r = cos(self.comps["X"]) + self.sin_r = sin(self.comps["X"]) # simulation plane always flat, dip = 0 - self.cosP = 0 # cos(self.comps['Z']) - self.sinP = 1 # sin(self.comps['Z']) + self.cos_p = 0 # cos(self.comps['Z']) + self.sin_p = 1 # sin(self.comps['Z']) # xy dual component rotation matrix # must also flip vertical axis theta = radians(self.mrot) @@ -256,6 +270,10 @@ def __init__( return if proc_local_file: + if not self.local_ny or not self.local_nx: + raise ValueError( + "Local nx, ny must be set when parsing a process local XYTS file." + ) self.data = np.memmap( xyts_path, dtype="%sf4" % (endian), @@ -288,6 +306,18 @@ def __init__( ll_map[ll_map[:, :, 0] < 0, 0] += 360 self.ll_map = ll_map + @overload + def corners( + self, gmt_format: Literal[True] + ) -> tuple[list[list[float]], str]: ... # numpydoc ignore=GL08 + + @overload + def corners( + self, gmt_format: Literal[False] + ) -> list[list[float]]: ... # numpydoc ignore=GL08 + + @overload + def corners(self) -> list[list[float]]: ... # numpydoc ignore=GL08 def corners( self, gmt_format: bool = False ) -> list[list[float]] | tuple[list[list[float]], str]: @@ -347,7 +377,7 @@ def region( The simulation region as a tuple (x_min, x_max, y_min, y_max). """ if corners is None: - corners = self.corners() + corners: np.ndarray = np.asarray(self.corners()) x_min, y_min = np.min(corners, axis=0) x_max, y_max = np.max(corners, axis=0) @@ -372,18 +402,22 @@ def tslice_get( np.ndarray Retrieved timeslice data. """ + if self.data is None: + raise AttributeError( + "The data attribute must be set to use `tslice_get`. Did you set `meta_only=True` when you initialised the class?" + ) match comp: case Component.MAGNITUDE: return np.linalg.norm(self.data[step, :3, :, :], axis=0) case Component.X: return ( - self.data[step, 0, :, :] * self.sinR - + self.data[step, 1, :, :] * self.cosR + self.data[step, 0, :, :] * self.sin_r + + self.data[step, 1, :, :] * self.cos_r ) case Component.Y: return ( - self.data[step, 0, :, :] * self.cosR - - self.data[step, 1, :, :] * self.sinR + self.data[step, 0, :, :] * self.cos_r + - self.data[step, 1, :, :] * self.sin_r ) case Component.Z: return self.data[step, 2, :, :] * -1 @@ -411,6 +445,11 @@ def pgv( PGV map or tuple of (PGV map, MMI map) or None (if both are written to a file). """ + + if self.data is None or self.ll_map is None: + raise AttributeError( + "The data and ll_map attributes must be set to use `pgv`. Did you set `meta_only=True` when you initialised the class?" + ) # PGV as timeslices reduced to maximum value at each point pgv = np.zeros(self.nx * self.ny) for ts in range(self.t0, self.nt): diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py index 41722809..1e053f42 100644 --- a/tests/test_coordinates.py +++ b/tests/test_coordinates.py @@ -1,13 +1,18 @@ +import re + import numpy as np import pyproj import pytest from hypothesis import given from hypothesis import strategies as st +from qcore import coordinates from qcore.coordinates import R_EARTH, SphericalProjection -def latitude(min_value: float = -90.0, max_value: float = 90.0, **kwargs): +def latitude( + min_value: float = -90.0, max_value: float = 90.0, **kwargs +) -> st.SearchStrategy: return st.floats( min_value=min_value, max_value=max_value, @@ -17,7 +22,9 @@ def latitude(min_value: float = -90.0, max_value: float = 90.0, **kwargs): ) -def longitude(min_value: float = -180.0, max_value: float = 180.0, **kwargs): +def longitude( + min_value: float = -180.0, max_value: float = 180.0, **kwargs +) -> st.SearchStrategy: return st.floats( min_value=min_value, max_value=max_value, @@ -32,7 +39,7 @@ def longitude(min_value: float = -180.0, max_value: float = 180.0, **kwargs): @st.composite -def points_in_same_hemisphere(draw): +def points_in_same_hemisphere(draw: st.DrawFn) -> tuple[float, float, float, float]: mlat = draw(latitude()) mlon = draw(longitude()) # Pick a second pair of points in the same hemisphere @@ -44,7 +51,9 @@ def points_in_same_hemisphere(draw): @given(points=points_in_same_hemisphere(), mrot=st.floats(-360, 360)) -def test_projection_inverse_is_identity(points, mrot): +def test_projection_inverse_is_identity( + points: tuple[float, float, float, float], mrot: float +) -> None: mlat, mlon, lat, lon = points proj = SphericalProjection(mlon, mlat, mrot) fwd = proj.project(lat, lon) @@ -64,7 +73,7 @@ def test_projection_inverse_is_identity(points, mrot): # longitude at the poles is equivalent to staying put (and hence the # mlon +/- eps) tests will always fail. @given(mlat=latitude(exclude_min=True, exclude_max=True), mlon=longitude()) -def test_identity_rotation_preserves_axes(mlat, mlon): +def test_identity_rotation_preserves_axes(mlat: float, mlon: float) -> None: proj = SphericalProjection(mlon, mlat, 0) # The coordinate frame of reference is south is y-positive, west @@ -79,19 +88,21 @@ def test_identity_rotation_preserves_axes(mlat, mlon): @given(mlat=latitude(), mlon=longitude(), mrot=st.floats(-360, 360)) -def test_center_maps_to_origin(mlat, mlon, mrot): +def test_center_maps_to_origin(mlat: float, mlon: float, mrot: float) -> None: proj = SphericalProjection(mlon, mlat, mrot) out = proj.project(mlat, mlon) assert pytest.approx(np.zeros_like(out), abs=1e-3) == out @given(points=points_in_same_hemisphere(), mrot=st.floats(-360, 360)) -def test_projection_preserves_distance(points, mrot): +def test_projection_preserves_distance( + points: tuple[float, float, float, float], mrot: float +) -> None: mlat, mlon, lat, lon = points proj = SphericalProjection(mlon, mlat, mrot) geod = proj.geod # Pyproj spherical geodesic - dist1 = GEOD.inv(mlon, mlat, lon, lat)[2] / 1000.0 + dist1 = geod.inv(mlon, mlat, lon, lat)[2] / 1000.0 x, y = proj.project(lat, lon) dist2 = np.hypot(x, y) @@ -100,7 +111,7 @@ def test_projection_preserves_distance(points, mrot): assert pytest.approx(dist1, abs=0.1) == dist2 -def test_projection_preserves_depth(): +def test_projection_preserves_depth() -> None: # Test that the projection does not change depth proj = SphericalProjection(0, 0, 0) # Centered at the origin depth = 100.0 @@ -110,7 +121,7 @@ def test_projection_preserves_depth(): assert pytest.approx(projected[2]) == depth # Depth should remain unchanged -def test_inverse_projection_preserves_depth(): +def test_inverse_projection_preserves_depth() -> None: # Test that the projection does not change depth proj = SphericalProjection(0, 0, 0) # Centered at the origin depth = 100.0 diff --git a/tests/test_geo.py b/tests/test_geo.py index 617b37f5..60bbb526 100644 --- a/tests/test_geo.py +++ b/tests/test_geo.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest from hypothesis import given @@ -8,28 +10,43 @@ @pytest.mark.parametrize( "test_b1, test_b2, expected_angle", - [(0, 360, 0), (80, 180, 100), (320, 0, 40), (180, -180, 0)], + [ + (0, 360, 0), + (80, 180, 100), + (320, 0, 40), + (180, -180, 0), + ( + 10, + 200, + -170, + ), # Test case where result > 180: (200-10)%360=190, returns 190-360=-170 + (0, 270, -90), # Another test case: (270-0)%360=270, returns 270-360=-90 + ], ) -def test_angle_diff(test_b1, test_b2, expected_angle): +def test_angle_diff(test_b1: float, test_b2: float, expected_angle: float) -> None: assert geo.angle_diff(test_b1, test_b2) == expected_angle @pytest.mark.parametrize( "test_lon1, test_lat1, test_lon2, test_lat2, test_midpoint, expected_bearing", [ - (120, 90, 180, 90, True, 74.99999999999997), + (120, 90, 180, 90, True, 75), (0, 0, 180, 0, False, 90), (-45, 0, 90, 0, True, 90), - (170, 90, 180, 90, False, 84.99999999999996), + (170, 90, 180, 90, False, 85), ], ) def test_ll_bearing( - test_lon1, test_lat1, test_lon2, test_lat2, test_midpoint, expected_bearing -): - assert ( - geo.ll_bearing(test_lon1, test_lat1, test_lon2, test_lat2, test_midpoint) - == expected_bearing - ) + test_lon1: float, + test_lat1: float, + test_lon2: float, + test_lat2: float, + test_midpoint: bool, + expected_bearing: float, +) -> None: + assert geo.ll_bearing( + test_lon1, test_lat1, test_lon2, test_lat2, test_midpoint + ) == pytest.approx(expected_bearing) @pytest.mark.parametrize( @@ -41,7 +58,13 @@ def test_ll_bearing( (45, 0, 90, 0, 5009.378656493638), ], ) -def test_ll_dist(test_lon1, test_lat1, test_lon2, test_lat2, expected_dist): +def test_ll_dist( + test_lon1: float, + test_lat1: float, + test_lon2: float, + test_lat2: float, + expected_dist: float, +) -> None: assert geo.ll_dist(test_lon1, test_lat1, test_lon2, test_lat2) == expected_dist @@ -55,8 +78,13 @@ def test_ll_dist(test_lon1, test_lat1, test_lon2, test_lat2, expected_dist): ], ) def test_ll_mid( - test_lon1, test_lat1, test_lon2, test_lat2, expected_mid_lon, expected_mid_lat -): + test_lon1: float, + test_lat1: float, + test_lon2: float, + test_lat2: float, + expected_mid_lon: float, + expected_mid_lat: float, +) -> None: assert geo.ll_mid(test_lon1, test_lat1, test_lon2, test_lat2) == ( expected_mid_lon, expected_mid_lat, @@ -73,8 +101,13 @@ def test_ll_mid( ], ) def test_ll_shift( - test_lat1, test_lon1, test_distance, test_bearing, expected_lat, expected_lon -): + test_lat1: float, + test_lon1: float, + test_distance: float, + test_bearing: float, + expected_lat: float, + expected_lon: float, +) -> None: assert geo.ll_shift(test_lat1, test_lon1, test_distance, test_bearing) == ( expected_lat, expected_lon, @@ -87,15 +120,15 @@ def test_ll_shift( ([[40, 1], [270, 1]], 335), ([[45, 10], [180, 1], [112.5, 2]], 59.252104114837415), ([[45, 1], [180, 1], [112.5, 2]], 112.5), - ([[45, 1], [180, 1]], 112.49999999999999), + ([[45, 1], [180, 1]], 112.5), ], ) -def test_avg_wbearing(test_angles, output_degrees): - assert geo.avg_wbearing(test_angles) == output_degrees +def test_avg_wbearing(test_angles: list[list[float]], output_degrees: float) -> None: + assert geo.avg_wbearing(test_angles) == pytest.approx(output_degrees) @given(target_bearing=st.floats(0, 360)) -def test_oriented_bearing_wrt_normal(target_bearing: float): +def test_oriented_bearing_wrt_normal(target_bearing: float) -> None: to_direction = np.array( [np.cos(np.radians(target_bearing)), np.sin(np.radians(target_bearing)), 0] ) @@ -129,12 +162,14 @@ def test_oriented_bearing_wrt_normal(target_bearing: float): ([1, 0], [2, 2], [0, 0], np.sqrt(0.5)), ], ) -def test_point_to_segment_distance(p, q, r, expected_distance): +def test_point_to_segment_distance( + p: list[float], q: list[float], r: list[float], expected_distance: float +) -> None: """Test the point_to_segment_distance function with various cases.""" assert geo.point_to_segment_distance(p, q, r) == pytest.approx(expected_distance) -def test_point_to_segement_degenerate(): +def test_point_to_segement_degenerate() -> None: """Test the failure case of a degenerate line.""" with pytest.raises(ValueError): geo.point_to_segment_distance([1, 1], [0, 0], [0, 0]) diff --git a/tests/test_xyts.py b/tests/test_xyts.py index 22008210..18d45d92 100644 --- a/tests/test_xyts.py +++ b/tests/test_xyts.py @@ -204,6 +204,5 @@ def test_xyts_invalid_file(tmp_path: Path) -> None: # Create a file with invalid header with open(invalid_file, "wb") as f: f.write(b"\x00" * 100) - with pytest.raises(ValueError, match="File is not an XY timeslice file"): xyts.XYTSFile(str(invalid_file)) From 5d1fb2511f0b792cdc54b472c07f0299345c65c8 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:13:35 +1300 Subject: [PATCH 02/12] fix: use ... in overloads --- qcore/src_site_dist.py | 4 ++-- qcore/uncertainties/distributions.py | 28 ++++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/qcore/src_site_dist.py b/qcore/src_site_dist.py index 2af181b6..ea0a629e 100644 --- a/qcore/src_site_dist.py +++ b/qcore/src_site_dist.py @@ -15,7 +15,7 @@ def calc_rrup_rjb( srf_points: np.ndarray, locations: np.ndarray, - n_stations_per_iter: int = 1000, + n_stations_per_iter: int = ..., return_rrup_points: Literal[True] = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ... # numpydoc ignore=GL08 @@ -24,7 +24,7 @@ def calc_rrup_rjb( def calc_rrup_rjb( srf_points: np.ndarray, locations: np.ndarray, - n_stations_per_iter: int = 1000, + n_stations_per_iter: int = ..., return_rrup_points: Literal[False] = False, ) -> tuple[np.ndarray, np.ndarray]: ... # numpydoc ignore=GL08 diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index 06c39b28..bdb5ffbd 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -24,7 +24,7 @@ def truncated_normal( mean: float, std_dev: float, - std_dev_limit: float = 2, + std_dev_limit: float = ..., size: Literal[1] = 1, seed: int | None = None, ) -> float: ... # numpydoc ignore=GL08 @@ -34,7 +34,7 @@ def truncated_normal( def truncated_normal( mean: float, std_dev: float, - std_dev_limit: float = 2, + std_dev_limit: float = ..., size: int = 1, seed: int | None = None, ) -> np.ndarray: ... # numpydoc ignore=GL08 @@ -81,20 +81,20 @@ def truncated_normal( @overload def truncated_weibull( upper_value: float, - c: float = 3.353, - scale_factor: float = 0.612, + c: float = ..., + scale_factor: float = ..., size: Literal[1] = 1, - seed: int | None = None, + seed: int | None = ..., ) -> float: ... # numpydoc ignore=GL08 @overload def truncated_weibull( upper_value: float, - c: float = 3.353, - scale_factor: float = 0.612, + c: float = ..., + scale_factor: float = ..., size: int = 1, - seed: int | None = None, + seed: int | None = ..., ) -> np.ndarray: ... # numpydoc ignore=GL08 @@ -165,9 +165,9 @@ def truncated_weibull_expected_value( def truncated_log_normal( mean: npt.ArrayLike, std_dev: float, - std_dev_limit: float = 2, + std_dev_limit: float = ..., size: Literal[1] = 1, - seed: int | None = None, + seed: int | None = ..., ) -> float: ... # numpydoc ignore=GL08 @@ -175,9 +175,9 @@ def truncated_log_normal( def truncated_log_normal( mean: npt.ArrayLike, std_dev: float, - std_dev_limit: float = 2, + std_dev_limit: float = ..., size: int = 1, - seed: int | None = None, + seed: int | None = ..., ) -> float: ... # numpydoc ignore=GL08 @@ -225,13 +225,13 @@ def truncated_log_normal( @overload def rand_shyp( - size: Literal[1] = 1, seed: int | None = None + size: Literal[1] = 1, seed: int | None = ... ) -> float: ... # numpydoc ignore=GL08 @overload def rand_shyp( - size: int = 1, seed: int | None = None + size: int = 1, seed: int | None = ... ) -> np.ndarray: ... # numpydoc ignore=GL08 From ff4b2d216056660f4f0b9cec04b7507435b0ab52 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:16:11 +1300 Subject: [PATCH 03/12] fix: truncated_log_normal return type --- qcore/uncertainties/distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index bdb5ffbd..30a46961 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -176,9 +176,9 @@ def truncated_log_normal( mean: npt.ArrayLike, std_dev: float, std_dev_limit: float = ..., - size: int = 1, + size: int = ..., seed: int | None = ..., -) -> float: ... # numpydoc ignore=GL08 +) -> np.ndarray: ... # numpydoc ignore=GL08 def truncated_log_normal( From 057f6b0e9ec891050453d79fb78e2966ea20923a Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:17:47 +1300 Subject: [PATCH 04/12] fix(distributions): more robust floating point returns --- qcore/uncertainties/distributions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index 30a46961..f3c8d17d 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -72,7 +72,7 @@ def truncated_normal( x = sp.stats.truncnorm(-std_dev_limit, std_dev_limit, loc=mean, scale=std_dev).rvs( size=size, random_state=seed ) - if size == 1: + if x.size == 1: return float(x.item()) else: return x @@ -129,7 +129,7 @@ def truncated_weibull( x = upper_value * sp.stats.truncweibull_min( c, 0, 1 / scale_factor, scale=scale_factor ).rvs(random_state=seed, size=size) - if size == 1: + if x.size == 1: return float(x.item()) else: return x @@ -217,7 +217,7 @@ def truncated_log_normal( scale=std_dev, ).rvs(random_state=seed) ) - if size == 1: + if x.size == 1: return float(x.item()) else: return x @@ -245,7 +245,7 @@ def rand_shyp(size: int = 1, seed: int | None = None) -> float | np.ndarray: Random value from a truncated normal distribution (mean=0, std_dev=0.25). """ x = truncated_normal(0, 0.25, size=size, seed=seed) - if size == 1: + if x.size == 1: return float(x.item()) else: return x From f53dce67654492e8bd3936342f5eab13923e83a6 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:22:25 +1300 Subject: [PATCH 05/12] fix(geo): overload `path_from_corners` --- qcore/geo.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/qcore/geo.py b/qcore/geo.py index 8721feff..8fafa2ef 100644 --- a/qcore/geo.py +++ b/qcore/geo.py @@ -3,6 +3,7 @@ """ from math import acos, asin, atan, atan2, cos, degrees, pi, radians, sin, sqrt +from typing import overload import numpy as np import numpy.typing as npt @@ -344,12 +345,30 @@ def avg_wbearing(angles: list[list[float]]) -> float: return degrees(atan(x / y) + q_diff) +@overload +def path_from_corners( + corners: list[tuple[float, float]], + output: str | None = None, + min_edge_points: int = ..., + close: bool = ..., +) -> list[tuple[float, float]]: ... + + +@overload +def path_from_corners( + corners: list[tuple[float, float]], + output: str = ..., + min_edge_points: int = ..., + close: bool = ..., +) -> None: ... + + def path_from_corners( corners: list[tuple[float, float]], output: str | None = "sim.modelpath_hr", min_edge_points: int = 100, close: bool = True, -): +) -> list[tuple[float, float]] | None: """ corners: python list (4 by 2) containing (lon, lat) in order otherwise take from velocity model From f5306459ab59b300cff3d2dc9b111a4b57d48d87 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:23:07 +1300 Subject: [PATCH 06/12] fix(xyts): remove duplicate data definitions --- qcore/xyts.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/qcore/xyts.py b/qcore/xyts.py index 2a6dc3b4..909880d3 100644 --- a/qcore/xyts.py +++ b/qcore/xyts.py @@ -161,11 +161,6 @@ class XYTSFile: local_nz: int | None = None - # data arrays - data: np.memmap | None = None - - ll_map: np.ndarray | None = None - # contents data: np.memmap | None = ( None # NOTE: this is distinct (but nearly identical to) a np.ndarray From 82687460e51fa5061f8c3fb2332f12ebac4bae51 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:23:50 +1300 Subject: [PATCH 07/12] tests: remove unused imports --- tests/test_coordinates.py | 3 --- tests/test_geo.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py index 1e053f42..16d62b33 100644 --- a/tests/test_coordinates.py +++ b/tests/test_coordinates.py @@ -1,12 +1,9 @@ -import re - import numpy as np import pyproj import pytest from hypothesis import given from hypothesis import strategies as st -from qcore import coordinates from qcore.coordinates import R_EARTH, SphericalProjection diff --git a/tests/test_geo.py b/tests/test_geo.py index 60bbb526..efdff5f9 100644 --- a/tests/test_geo.py +++ b/tests/test_geo.py @@ -1,5 +1,3 @@ -from pathlib import Path - import numpy as np import pytest from hypothesis import given From f0ae08e09799bd566d82620662e89dd4dc6ced54 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:24:57 +1300 Subject: [PATCH 08/12] fix(distributions): truncated log normal takes size argument --- qcore/uncertainties/distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index f3c8d17d..68ea895e 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -215,7 +215,7 @@ def truncated_log_normal( std_dev_limit, loc=np.log(np.asarray(mean).astype(np.float64)), scale=std_dev, - ).rvs(random_state=seed) + ).rvs(size=size, random_state=seed) ) if x.size == 1: return float(x.item()) From 129f9e5306720c72ebf8c27a1595990d77a6c194 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:25:52 +1300 Subject: [PATCH 09/12] fix(distributions): crash when size=1 for rand_shyp --- qcore/uncertainties/distributions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/qcore/uncertainties/distributions.py b/qcore/uncertainties/distributions.py index 68ea895e..b91601ac 100644 --- a/qcore/uncertainties/distributions.py +++ b/qcore/uncertainties/distributions.py @@ -244,8 +244,4 @@ def rand_shyp(size: int = 1, seed: int | None = None) -> float | np.ndarray: float Random value from a truncated normal distribution (mean=0, std_dev=0.25). """ - x = truncated_normal(0, 0.25, size=size, seed=seed) - if x.size == 1: - return float(x.item()) - else: - return x + return truncated_normal(0, 0.25, size=size, seed=seed) From e833340a8864d992b01d0ddc356b25d6da7f7bbb Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 13:29:39 +1300 Subject: [PATCH 10/12] fix(constants): short-circuitted logic in is_substring Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- qcore/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qcore/constants.py b/qcore/constants.py index 77b0da01..f2c5357a 100644 --- a/qcore/constants.py +++ b/qcore/constants.py @@ -19,7 +19,7 @@ def has_value(cls, value: Any) -> bool: def is_substring(cls, parent_string: str) -> bool: """Check if an enum's string value is contained in the given string""" return any( - not isinstance(item.value, str) or item.value in parent_string + isinstance(item.value, str) and item.value in parent_string for item in cls ) From 3caa69ff903401ba16cdbd927fc7ff5d1cb79991 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 14:01:44 +1300 Subject: [PATCH 11/12] fix(formats): correct return type station_file_argparser never actually returns None --- qcore/formats.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/qcore/formats.py b/qcore/formats.py index 1b139032..e1f6ec01 100644 --- a/qcore/formats.py +++ b/qcore/formats.py @@ -35,20 +35,10 @@ def load_im_file_pd( return df -@overload -def station_file_argparser( - parser: argparse.ArgumentParser, -) -> None: ... # numpydoc ignore=GL08 - - -@overload -def station_file_argparser() -> argparse.ArgumentParser: ... # numpydoc ignore=GL08 - - @deprecated("Will be removed after Cybershake investigation concludes.") def station_file_argparser( parser: argparse.ArgumentParser | None = None, -) -> argparse.ArgumentParser | None: +) -> argparse.ArgumentParser: """ Return a parser object with formatting information of a generic station file. To facilitate the use of load_generic_station_file() From 534579fb37e286eb5a702f74da5a9b96d4dbaa93 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 8 Jan 2026 14:05:35 +1300 Subject: [PATCH 12/12] ci: run type checker on all PRs --- .github/workflows/types.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/types.yml b/.github/workflows/types.yml index 0f7a1c85..ea628a43 100644 --- a/.github/workflows/types.yml +++ b/.github/workflows/types.yml @@ -1,10 +1,6 @@ name: Type Check -on: - push: - branches: [master] - pull_request: - branches: [master] +on: [pull_request] jobs: typecheck: @@ -19,7 +15,6 @@ jobs: with: python-version: "3.13" - - name: Install uv uses: astral-sh/setup-uv@v5 with: