From 3b34509dd66ac26b507e0aa6b81fb4be754800b7 Mon Sep 17 00:00:00 2001 From: Georg Plaz Date: Thu, 24 Jul 2025 14:35:34 +0200 Subject: [PATCH] Ensure protected names (except for classes) are not imported, mainly to avoid importing _S and other TypeVars --- src/spellbind/float_values.py | 14 ++++----- src/spellbind/functions.py | 42 ++++---------------------- src/spellbind/int_collections.py | 7 +++-- src/spellbind/int_values.py | 12 ++++---- src/spellbind/numbers.py | 31 +++++++++++++++++++ src/spellbind/str_collections.py | 7 +++-- src/spellbind/values.py | 16 +++++----- tests/conftest.py | 52 +++++++++++++++++++++++++++++++- tests/test_imports.py | 13 ++++++++ 9 files changed, 132 insertions(+), 62 deletions(-) create mode 100644 src/spellbind/numbers.py create mode 100644 tests/test_imports.py diff --git a/src/spellbind/float_values.py b/src/spellbind/float_values.py index e00e7d2..6a52fb2 100644 --- a/src/spellbind/float_values.py +++ b/src/spellbind/float_values.py @@ -9,9 +9,9 @@ from typing_extensions import TYPE_CHECKING from spellbind.bool_values import BoolValue -from spellbind.functions import _clamp_float, _multiply_all_floats +from spellbind.numbers import multiply_all_floats, clamp_float from spellbind.values import Value, SimpleVariable, OneToOneValue, DerivedValueBase, Constant, \ - NotConstantError, ThreeToOneValue, _create_value_getter, get_constant_of_generic_like + NotConstantError, ThreeToOneValue, create_value_getter, get_constant_of_generic_like if TYPE_CHECKING: from spellbind.int_values import IntValue, IntLike # pragma: no cover @@ -41,10 +41,10 @@ def __rsub__(self, other: int | float) -> FloatValue: return FloatValue.derive_from_two(operator.sub, other, self) def __mul__(self, other: FloatLike) -> FloatValue: - return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True) + return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True) def __rmul__(self, other: int | float) -> FloatValue: - return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True) + return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True) def __truediv__(self, other: FloatLike) -> FloatValue: return FloatValue.derive_from_two(operator.truediv, self, other) @@ -110,7 +110,7 @@ def __pos__(self) -> Self: return self def clamp(self, min_value: FloatLike, max_value: FloatLike) -> FloatValue: - return FloatValue.derive_from_three_floats(_clamp_float, self, min_value, max_value) + return FloatValue.derive_from_three_floats(clamp_float, self, min_value, max_value) def decompose_float_operands(self, operator_: Callable[..., float]) -> Sequence[FloatLike]: return (self,) @@ -204,7 +204,7 @@ def sum_floats(*values: FloatLike) -> FloatValue: def multiply_floats(*values: FloatLike) -> FloatValue: - return FloatValue.derive_from_many(_multiply_all_floats, *values, is_associative=True) + return FloatValue.derive_from_many(multiply_all_floats, *values, is_associative=True) class OneToFloatValue(Generic[_S], OneToOneValue[_S, float], FloatValue): @@ -327,7 +327,7 @@ def __init__(self, transformer: Callable[[float, int], _S], self._of_first = first self._of_second = second self._first_getter = _create_float_getter(first) - self._second_getter = _create_value_getter(second) + self._second_getter = create_value_getter(second) super().__init__(*[v for v in (first, second) if isinstance(v, Value)]) @override diff --git a/src/spellbind/functions.py b/src/spellbind/functions.py index 5c148f8..e9c4139 100644 --- a/src/spellbind/functions.py +++ b/src/spellbind/functions.py @@ -1,9 +1,9 @@ import inspect from inspect import Parameter -from typing import Callable, Iterable, Any +from typing import Callable, Any -def _is_positional_parameter(param: Parameter) -> bool: +def is_positional_parameter(param: Parameter) -> bool: return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) @@ -14,16 +14,16 @@ def has_var_args(function: Callable[..., Any]) -> bool: def count_positional_parameters(function: Callable[..., Any]) -> int: parameters = inspect.signature(function).parameters - return sum(1 for parameter in parameters.values() if _is_positional_parameter(parameter)) + return sum(1 for parameter in parameters.values() if is_positional_parameter(parameter)) -def _is_required_positional_parameter(param: Parameter) -> bool: - return param.default == param.empty and _is_positional_parameter(param) +def is_required_positional_parameter(param: Parameter) -> bool: + return param.default == param.empty and is_positional_parameter(param) def count_non_default_parameters(function: Callable[..., Any]) -> int: parameters = inspect.signature(function).parameters - return sum(1 for param in parameters.values() if _is_required_positional_parameter(param)) + return sum(1 for param in parameters.values() if is_required_positional_parameter(param)) def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) -> None: @@ -36,33 +36,3 @@ def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) -> callable_name = str(callable_) # pragma: no cover raise ValueError(f"Callable {callable_name} has too many non-default parameters: " f"{count_non_default_parameters(callable_)} > {max_count}") - - -def _multiply_all_ints(vals: Iterable[int]) -> int: - result = 1 - for val in vals: - result *= val - return result - - -def _multiply_all_floats(vals: Iterable[float]) -> float: - result = 1. - for val in vals: - result *= val - return result - - -def _clamp_int(value: int, min_value: int, max_value: int) -> int: - if value < min_value: - return min_value - elif value > max_value: - return max_value - return value - - -def _clamp_float(value: float, min_value: float, max_value: float) -> float: - if value < min_value: - return min_value - elif value > max_value: - return max_value - return value diff --git a/src/spellbind/int_collections.py b/src/spellbind/int_collections.py index 9e9ff10..42acdaa 100644 --- a/src/spellbind/int_collections.py +++ b/src/spellbind/int_collections.py @@ -3,17 +3,20 @@ import operator from abc import ABC, abstractmethod from functools import cached_property -from typing import Iterable, Callable, Any +from typing import Iterable, Callable, Any, TypeVar from typing_extensions import TypeIs, override from spellbind.int_values import IntValue, IntConstant from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection -from spellbind.observable_sequences import ObservableList, _S, TypedValueList, ValueSequence, UnboxedValueSequence, \ +from spellbind.observable_sequences import ObservableList, TypedValueList, ValueSequence, UnboxedValueSequence, \ ObservableSequence from spellbind.values import Value +_S = TypeVar("_S") + + class ObservableIntCollection(ObservableCollection[int], ABC): @property def summed(self) -> IntValue: diff --git a/src/spellbind/int_values.py b/src/spellbind/int_values.py index 3402168..2d674e8 100644 --- a/src/spellbind/int_values.py +++ b/src/spellbind/int_values.py @@ -9,7 +9,7 @@ from spellbind.bool_values import BoolValue from spellbind.float_values import FloatValue, \ CompareNumbersValues -from spellbind.functions import _clamp_int, _multiply_all_ints, _multiply_all_floats +from spellbind.numbers import multiply_all_ints, multiply_all_floats, clamp_int from spellbind.values import Value, SimpleVariable, TwoToOneValue, OneToOneValue, Constant, \ ThreeToOneValue, NotConstantError, ManyToSameValue, get_constant_of_generic_like @@ -75,8 +75,8 @@ def __mul__(self, other: float | FloatValue) -> FloatValue: ... def __mul__(self, other: FloatLike) -> IntValue | FloatValue: if isinstance(other, (float, FloatValue)): - return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True) - return IntValue.derive_from_many(_multiply_all_ints, self, other, is_associative=True) + return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True) + return IntValue.derive_from_many(multiply_all_ints, self, other, is_associative=True) @overload def __rmul__(self, other: int) -> IntValue: ... @@ -86,8 +86,8 @@ def __rmul__(self, other: float) -> FloatValue: ... def __rmul__(self, other: int | float) -> IntValue | FloatValue: if isinstance(other, float): - return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True) - return IntValue.derive_from_many(_multiply_all_ints, other, self, is_associative=True) + return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True) + return IntValue.derive_from_many(multiply_all_ints, other, self, is_associative=True) def __truediv__(self, other: FloatLike) -> FloatValue: return FloatValue.derive_from_two(operator.truediv, self, other) @@ -135,7 +135,7 @@ def __pos__(self) -> Self: return self def clamp(self, min_value: IntLike, max_value: IntLike) -> IntValue: - return IntValue.derive_from_three(_clamp_int, self, min_value, max_value) + return IntValue.derive_from_three(clamp_int, self, min_value, max_value) @classmethod def derive_from_one(cls, operator_: Callable[[_S], int], value: _S | Value[_S]) -> IntValue: diff --git a/src/spellbind/numbers.py b/src/spellbind/numbers.py new file mode 100644 index 0000000..b0da64b --- /dev/null +++ b/src/spellbind/numbers.py @@ -0,0 +1,31 @@ +from typing import Iterable + + +def multiply_all_ints(vals: Iterable[int]) -> int: + result = 1 + for val in vals: + result *= val + return result + + +def multiply_all_floats(vals: Iterable[float]) -> float: + result = 1. + for val in vals: + result *= val + return result + + +def clamp_int(value: int, min_value: int, max_value: int) -> int: + if value < min_value: + return min_value + elif value > max_value: + return max_value + return value + + +def clamp_float(value: float, min_value: float, max_value: float) -> float: + if value < min_value: + return min_value + elif value > max_value: + return max_value + return value diff --git a/src/spellbind/str_collections.py b/src/spellbind/str_collections.py index 4603898..d833a2c 100644 --- a/src/spellbind/str_collections.py +++ b/src/spellbind/str_collections.py @@ -1,15 +1,18 @@ from abc import ABC -from typing import Iterable, Callable, Any +from typing import Iterable, Callable, Any, TypeVar from typing_extensions import TypeIs from spellbind.int_values import IntValue from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue -from spellbind.observable_sequences import ObservableList, _S, TypedValueList +from spellbind.observable_sequences import ObservableList, TypedValueList from spellbind.str_values import StrValue, StrConstant from spellbind.values import Value +_S = TypeVar("_S") + + class ObservableStrCollection(ObservableCollection[str], ABC): @property def concatenated(self) -> StrValue: diff --git a/src/spellbind/values.py b/src/spellbind/values.py index a42a336..0d2250a 100644 --- a/src/spellbind/values.py +++ b/src/spellbind/values.py @@ -27,7 +27,7 @@ _W = TypeVar("_W") -def _create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]: +def create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]: if isinstance(value, Value): return lambda: value.value else: @@ -385,7 +385,7 @@ class OneToOneValue(DerivedValueBase[_T], Generic[_S, _T]): _getter: Callable[[], _S] def __init__(self, transformer: Callable[[_S], _T], of: Value[_S]) -> None: - self._getter = _create_value_getter(of) + self._getter = create_value_getter(of) self._of = of self._transformer = transformer super().__init__(*[v for v in (of,) if isinstance(v, Value)]) @@ -398,7 +398,7 @@ def _calculate_value(self) -> _T: class ManyToOneValue(DerivedValueBase[_T], Generic[_S, _T]): def __init__(self, transformer: Callable[[Iterable[_S]], _T], *values: _S | Value[_S]): self._input_values = tuple(values) - self._value_getters = [_create_value_getter(v) for v in self._input_values] + self._value_getters = [create_value_getter(v) for v in self._input_values] self._transformer = transformer super().__init__(*[v for v in self._input_values if isinstance(v, Value)]) @@ -422,8 +422,8 @@ def __init__(self, transformer: Callable[[_S, _T], _U], self._transformer = transformer self._of_first = first self._of_second = second - self._first_getter = _create_value_getter(first) - self._second_getter = _create_value_getter(second) + self._first_getter = create_value_getter(first) + self._second_getter = create_value_getter(second) super().__init__(*[v for v in (first, second) if isinstance(v, Value)]) @override @@ -438,9 +438,9 @@ def __init__(self, transformer: Callable[[_S, _T, _U], _V], self._of_first = first self._of_second = second self._of_third = third - self._first_getter = _create_value_getter(first) - self._second_getter = _create_value_getter(second) - self._third_getter = _create_value_getter(third) + self._first_getter = create_value_getter(first) + self._second_getter = create_value_getter(second) + self._third_getter = create_value_getter(third) super().__init__(*[v for v in (first, second, third) if isinstance(v, Value)]) @override diff --git a/tests/conftest.py b/tests/conftest.py index 3142381..25e3529 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import ast from contextlib import contextmanager -from typing import Any, Sequence, Callable +from pathlib import Path +from typing import Any, Sequence, Callable, Generator, Tuple from typing import Iterable, overload, Collection from unittest.mock import Mock @@ -13,6 +15,54 @@ _S = TypeVar("_S") +PROJECT_ROOT_PATH = Path(__file__).parent.parent.resolve() +SOURCE_PATH = PROJECT_ROOT_PATH / "src" + + +def iter_python_files(source_path: Path) -> Generator[Path, None, None]: + yield from source_path.rglob("*.py") + + +def is_class_definition(module_path: Path, object_name: str) -> bool: + text = module_path.read_text(encoding="utf-8") + node = ast.parse(text, filename=str(module_path)) + for item in node.body: + if hasattr(item, "name") and getattr(item, "name") == object_name: + if isinstance(item, ast.ClassDef): + return True + else: + return False + return False + + +def resolve_module_path(base_path: Path, module: str) -> Path: + unfinished_module_path = base_path / Path(*module.split(".")) + init_path = unfinished_module_path / "__init__.py" + if init_path.exists(): + return init_path + file_path = unfinished_module_path.with_suffix(".py") + return file_path + + +def is_class_import(alias: ast.alias, import_: ast.ImportFrom, source_root: Path = SOURCE_PATH) -> bool: + module = import_.module + if module is None: + return False + module_path = resolve_module_path(source_root, module) + if module_path is None: + return False + return is_class_definition(module_path, alias.name) + + +def iter_imported_aliases(file_path: Path) -> Generator[Tuple[ast.alias, ast.ImportFrom], None, None]: + text = file_path.read_text(encoding="utf-8") + node = ast.parse(text, filename=str(file_path)) + for statement in ast.walk(node): + if isinstance(statement, ast.ImportFrom): + for alias_ in statement.names: + yield alias_, statement + + class Call: def __init__(self, *args, **kwargs): self.args = args diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..46afa7d --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,13 @@ +import pytest + +from conftest import iter_imported_aliases, SOURCE_PATH, iter_python_files, is_class_import + + +def test_no_protected_imports_except_for_classes(): + lines = [] + for file in iter_python_files(SOURCE_PATH): + for alias_, statement in iter_imported_aliases(file): + if alias_.name.startswith("_") and not is_class_import(alias_, statement): + lines.append(f"{file}:{statement.lineno}: imports protected name '{alias_.name}'") + if len(lines) > 0: + pytest.fail(f"Found {len(lines)} protected imports\n" + "\n".join(lines))