Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/types.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Type Check

on: [pull_request]

jobs:
typecheck:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.13"

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- name: Install project with types
run: uv sync --all-extras --dev

- name: Run type checking with ty
run: uv run ty check
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies = [
"pyproj",
"pooch",
"scipy",
"matplotlib",
"numpy",
"numba>=0.57.0; python_version == '3.11'",
"numba>=0.59.0; python_version == '3.12'",
Expand All @@ -25,6 +24,7 @@ dependencies = [
"docstring_parser",
"xarray",
"pyyaml",
"typing-extensions>=4.9.0", # Last change to the @deprecated module
]

[project.optional-dependencies]
Expand Down
52 changes: 43 additions & 9 deletions qcore/archive_structure.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,61 @@
"""
Gives access to the folder structure of archived cybershake directories
"""

from pathlib import Path

from .simulation_structure import get_fault_from_realisation


def get_fault_source_dir(fault_dir: Path):
"""Gets the Source directory for the given fault directory"""
def get_fault_source_dir(fault_dir: Path) -> Path:
"""
Get the Source directory for a given fault directory.

Parameters
----------
fault_dir : Path
Path to the fault directory.

Returns
-------
Path
Path to the Source directory within the given fault directory.
"""
return fault_dir / "Source"


def get_fault_im_dir(fault_dir: Path):
"""Gets the IM directory for the given fault directory"""
def get_fault_im_dir(fault_dir: Path) -> Path:
"""
Get the IM directory for a given fault directory.

Parameters
----------
fault_dir : Path
Path to the fault directory.

Returns
-------
Path
Path to the IM directory within the given fault directory.
"""
return fault_dir / "IM"


def get_fault_bb_dir(fault_dir: Path):
"""Gets the BB directory for the given fault directory"""
return fault_dir / "BB"
def get_IM_csv_from_root(archive_root: Path, realisation: str) -> Path: # noqa: N802
"""
Get the full path to the IM CSV file given the archive root and realisation name.

Parameters
----------
archive_root : Path
Path to the root directory of the Cybershake archive.
realisation : str
Name of the realisation to locate.

def get_IM_csv_from_root(archive_root: Path, realisation: str):
"""Gets the full path to the im_csv file given the archive root dir and the realistion name"""
Returns
-------
Path
Full path to the IM CSV file for the specified realisation.
"""
fault_name = get_fault_from_realisation(realisation)
return get_fault_im_dir(archive_root / fault_name) / f"{realisation}.csv"
50 changes: 32 additions & 18 deletions qcore/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,48 @@
import inspect
from collections.abc import Callable
from functools import wraps
from typing import Annotated, Any, get_args, get_origin
from typing import (
Annotated,
Any,
ParamSpec,
TypeVar,
get_args,
get_origin,
)

import docstring_parser
import typer
from docstring_parser.common import DocstringStyle
from typer.models import ArgumentInfo, OptionInfo

# P captures the parameters (args and kwargs) of the decorated function.
P = ParamSpec("P")

# Originally written by @Genfood: https://github.com/fastapi/typer/issues/336#issuecomment-2434726193
# Updated and modified for Python 3.13.
def from_docstring(app: typer.Typer, **kwargs: dict) -> Callable:
# R captures the return type of the decorated function.
R = TypeVar("R")


def from_docstring(
app: typer.Typer,
**kwargs: Any,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Apply help texts from the function's docstring to Typer arguments/options and command.

Parameters
----------
app : typer.Typer
The Typer application to which the command will be registered.
**kwargs : dict
**kwargs : Any
Additional keyword arguments to be passed to the Typer command.

Returns
-------
Callable
The decorated function with help texts applied, without overwriting
existing settings.
Callable[[Callable[P, R]], Callable[P, R]]
A decorator function that takes a command (Callable[P, R]) and
returns a wrapper (Callable[P, R]) preserving its signature P and return R.
"""

def decorator(command: Callable) -> Callable:
def decorator(command: Callable[P, R]) -> Callable[P, R]: # numpydoc ignore=GL08
if command.__doc__ is None:
return command

Expand Down Expand Up @@ -66,9 +81,7 @@ def decorator(command: Callable) -> Callable:
param_type, *metadata = get_args(param_type)
new_metadata = []
for m in metadata:
if isinstance(
m, typer.models.ArgumentInfo | typer.models.OptionInfo
):
if isinstance(m, ArgumentInfo | OptionInfo):
if not m.help:
m.help = help_text
new_metadata.append(m)
Expand All @@ -77,9 +90,7 @@ def decorator(command: Callable) -> Callable:
)

# If it's an Option or Argument directly
elif isinstance(
param.default, (typer.models.ArgumentInfo, typer.models.OptionInfo)
):
elif isinstance(param.default, ArgumentInfo | OptionInfo):
if not param.default.help:
param.default.help = help_text
new_param = param
Expand All @@ -103,15 +114,18 @@ def decorator(command: Callable) -> Callable:
# Create a new signature with updated parameters
new_sig = sig.replace(parameters=new_parameters)

# Apply the new signature to the wrapper function

# Register the command with the app
# Since the signature (P, R) is applied to the decorator result,
# the wrapper's type definition must match what command returns (R).
@app.command(help=command_help.strip(), **kwargs)
@wraps(command)
def wrapper(*args: Any, **kwargs: Any) -> Any: # numpydoc ignore=GL08
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # numpydoc ignore=GL08
return command(*args, **kwargs)

# NOTE: Typer requires the dynamic signature update for runtime reflection,
# but the type checker uses the P, R generics.
wrapper.__signature__ = new_sig

return wrapper

return decorator
30 changes: 19 additions & 11 deletions qcore/constants.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
"""DEPRECATED - Global constants and Enum helper."""

from collections.abc import Generator
from enum import Enum
from typing import Any
from warnings import deprecated # type: ignore

from typing_extensions import deprecated # type: ignore

@deprecated

@deprecated("Use built-in Enum")
class ExtendedEnum(Enum):
"""DEPRECATED: Utility enum extension. Use built-in Enum."""

@classmethod
def has_value(cls, value):
def has_value(cls, value: Any) -> bool:
return any(value == item.value for item in cls)

@classmethod
def is_substring(cls, parent_string):
def is_substring(cls, parent_string: str) -> bool:
"""Check if an enum's string value is contained in the given string"""
return any(item.value in parent_string for item in cls)
return any(
isinstance(item.value, str) and item.value in parent_string
for item in cls
)

@classmethod
def get_names(cls):
def get_names(cls) -> list[str]:
return [item.name for item in cls]

def __str__(self):
def __str__(self) -> str:
return self.name


@deprecated
class ExtendedStrEnum(ExtendedEnum):
@deprecated("Use built-in StrEnum")
class ExtendedStrEnum(ExtendedEnum): # type: ignore
"""DEPRECATED: Utility Enum extension for string mappings. Use built-in StrEnum."""

_value_: Any
str_value: str

def __new__(cls, value: Any, str_value: str): # noqa: D102 # numpydoc ignore=GL08
obj = object.__new__(cls)
obj._value_ = value
obj.str_value = str_value
return obj

@classmethod
def has_str_value(cls, str_value):
def has_str_value(cls, str_value: str) -> bool:
return any(str_value == item.str_value for item in cls)

@classmethod
Expand All @@ -50,7 +58,7 @@ def from_str(cls, str_value):
return item

@classmethod
def iterate_str_values(cls, ignore_none=True):
def iterate_str_values(cls, ignore_none: bool = True) -> Generator[Any, None, None]:
"""Iterates over the string values of the enum,
ignores entries without a string value by default
"""
Expand Down
Loading
Loading