Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/spellbind/float_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing_extensions import TYPE_CHECKING

from spellbind.bool_values import BoolValue
from spellbind.functions import _clamp_float, _multiply_all_floats
from spellbind.numbers import multiply_all_floats, clamp_float
from spellbind.values import Value, SimpleVariable, OneToOneValue, DerivedValueBase, Constant, \
NotConstantError, ThreeToOneValue, _create_value_getter, get_constant_of_generic_like
NotConstantError, ThreeToOneValue, create_value_getter, get_constant_of_generic_like

if TYPE_CHECKING:
from spellbind.int_values import IntValue, IntLike # pragma: no cover
Expand Down Expand Up @@ -41,10 +41,10 @@ def __rsub__(self, other: int | float) -> FloatValue:
return FloatValue.derive_from_two(operator.sub, other, self)

def __mul__(self, other: FloatLike) -> FloatValue:
return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True)
return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True)

def __rmul__(self, other: int | float) -> FloatValue:
return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True)
return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True)

def __truediv__(self, other: FloatLike) -> FloatValue:
return FloatValue.derive_from_two(operator.truediv, self, other)
Expand Down Expand Up @@ -110,7 +110,7 @@ def __pos__(self) -> Self:
return self

def clamp(self, min_value: FloatLike, max_value: FloatLike) -> FloatValue:
return FloatValue.derive_from_three_floats(_clamp_float, self, min_value, max_value)
return FloatValue.derive_from_three_floats(clamp_float, self, min_value, max_value)

def decompose_float_operands(self, operator_: Callable[..., float]) -> Sequence[FloatLike]:
return (self,)
Expand Down Expand Up @@ -204,7 +204,7 @@ def sum_floats(*values: FloatLike) -> FloatValue:


def multiply_floats(*values: FloatLike) -> FloatValue:
return FloatValue.derive_from_many(_multiply_all_floats, *values, is_associative=True)
return FloatValue.derive_from_many(multiply_all_floats, *values, is_associative=True)


class OneToFloatValue(Generic[_S], OneToOneValue[_S, float], FloatValue):
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(self, transformer: Callable[[float, int], _S],
self._of_first = first
self._of_second = second
self._first_getter = _create_float_getter(first)
self._second_getter = _create_value_getter(second)
self._second_getter = create_value_getter(second)
super().__init__(*[v for v in (first, second) if isinstance(v, Value)])

@override
Expand Down
42 changes: 6 additions & 36 deletions src/spellbind/functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
from inspect import Parameter
from typing import Callable, Iterable, Any
from typing import Callable, Any


def _is_positional_parameter(param: Parameter) -> bool:
def is_positional_parameter(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


Expand All @@ -14,16 +14,16 @@ def has_var_args(function: Callable[..., Any]) -> bool:

def count_positional_parameters(function: Callable[..., Any]) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for parameter in parameters.values() if _is_positional_parameter(parameter))
return sum(1 for parameter in parameters.values() if is_positional_parameter(parameter))


def _is_required_positional_parameter(param: Parameter) -> bool:
return param.default == param.empty and _is_positional_parameter(param)
def is_required_positional_parameter(param: Parameter) -> bool:
return param.default == param.empty and is_positional_parameter(param)


def count_non_default_parameters(function: Callable[..., Any]) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for param in parameters.values() if _is_required_positional_parameter(param))
return sum(1 for param in parameters.values() if is_required_positional_parameter(param))


def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) -> None:
Expand All @@ -36,33 +36,3 @@ def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) ->
callable_name = str(callable_) # pragma: no cover
raise ValueError(f"Callable {callable_name} has too many non-default parameters: "
f"{count_non_default_parameters(callable_)} > {max_count}")


def _multiply_all_ints(vals: Iterable[int]) -> int:
result = 1
for val in vals:
result *= val
return result


def _multiply_all_floats(vals: Iterable[float]) -> float:
result = 1.
for val in vals:
result *= val
return result


def _clamp_int(value: int, min_value: int, max_value: int) -> int:
if value < min_value:
return min_value
elif value > max_value:
return max_value
return value


def _clamp_float(value: float, min_value: float, max_value: float) -> float:
if value < min_value:
return min_value
elif value > max_value:
return max_value
return value
7 changes: 5 additions & 2 deletions src/spellbind/int_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import operator
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Iterable, Callable, Any
from typing import Iterable, Callable, Any, TypeVar

from typing_extensions import TypeIs, override

from spellbind.int_values import IntValue, IntConstant
from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection
from spellbind.observable_sequences import ObservableList, _S, TypedValueList, ValueSequence, UnboxedValueSequence, \
from spellbind.observable_sequences import ObservableList, TypedValueList, ValueSequence, UnboxedValueSequence, \
ObservableSequence
from spellbind.values import Value


_S = TypeVar("_S")


class ObservableIntCollection(ObservableCollection[int], ABC):
@property
def summed(self) -> IntValue:
Expand Down
12 changes: 6 additions & 6 deletions src/spellbind/int_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from spellbind.bool_values import BoolValue
from spellbind.float_values import FloatValue, \
CompareNumbersValues
from spellbind.functions import _clamp_int, _multiply_all_ints, _multiply_all_floats
from spellbind.numbers import multiply_all_ints, multiply_all_floats, clamp_int
from spellbind.values import Value, SimpleVariable, TwoToOneValue, OneToOneValue, Constant, \
ThreeToOneValue, NotConstantError, ManyToSameValue, get_constant_of_generic_like

Expand Down Expand Up @@ -75,8 +75,8 @@ def __mul__(self, other: float | FloatValue) -> FloatValue: ...

def __mul__(self, other: FloatLike) -> IntValue | FloatValue:
if isinstance(other, (float, FloatValue)):
return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True)
return IntValue.derive_from_many(_multiply_all_ints, self, other, is_associative=True)
return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True)
return IntValue.derive_from_many(multiply_all_ints, self, other, is_associative=True)

@overload
def __rmul__(self, other: int) -> IntValue: ...
Expand All @@ -86,8 +86,8 @@ def __rmul__(self, other: float) -> FloatValue: ...

def __rmul__(self, other: int | float) -> IntValue | FloatValue:
if isinstance(other, float):
return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True)
return IntValue.derive_from_many(_multiply_all_ints, other, self, is_associative=True)
return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True)
return IntValue.derive_from_many(multiply_all_ints, other, self, is_associative=True)

def __truediv__(self, other: FloatLike) -> FloatValue:
return FloatValue.derive_from_two(operator.truediv, self, other)
Expand Down Expand Up @@ -135,7 +135,7 @@ def __pos__(self) -> Self:
return self

def clamp(self, min_value: IntLike, max_value: IntLike) -> IntValue:
return IntValue.derive_from_three(_clamp_int, self, min_value, max_value)
return IntValue.derive_from_three(clamp_int, self, min_value, max_value)

@classmethod
def derive_from_one(cls, operator_: Callable[[_S], int], value: _S | Value[_S]) -> IntValue:
Expand Down
31 changes: 31 additions & 0 deletions src/spellbind/numbers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Iterable


def multiply_all_ints(vals: Iterable[int]) -> int:
result = 1
for val in vals:
result *= val
return result


def multiply_all_floats(vals: Iterable[float]) -> float:
result = 1.
for val in vals:
result *= val
return result


def clamp_int(value: int, min_value: int, max_value: int) -> int:
if value < min_value:
return min_value
elif value > max_value:
return max_value
return value


def clamp_float(value: float, min_value: float, max_value: float) -> float:
if value < min_value:
return min_value
elif value > max_value:
return max_value
return value
7 changes: 5 additions & 2 deletions src/spellbind/str_collections.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from abc import ABC
from typing import Iterable, Callable, Any
from typing import Iterable, Callable, Any, TypeVar

from typing_extensions import TypeIs

from spellbind.int_values import IntValue
from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue
from spellbind.observable_sequences import ObservableList, _S, TypedValueList
from spellbind.observable_sequences import ObservableList, TypedValueList
from spellbind.str_values import StrValue, StrConstant
from spellbind.values import Value


_S = TypeVar("_S")


class ObservableStrCollection(ObservableCollection[str], ABC):
@property
def concatenated(self) -> StrValue:
Expand Down
16 changes: 8 additions & 8 deletions src/spellbind/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_W = TypeVar("_W")


def _create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]:
def create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]:
if isinstance(value, Value):
return lambda: value.value
else:
Expand Down Expand Up @@ -385,7 +385,7 @@ class OneToOneValue(DerivedValueBase[_T], Generic[_S, _T]):
_getter: Callable[[], _S]

def __init__(self, transformer: Callable[[_S], _T], of: Value[_S]) -> None:
self._getter = _create_value_getter(of)
self._getter = create_value_getter(of)
self._of = of
self._transformer = transformer
super().__init__(*[v for v in (of,) if isinstance(v, Value)])
Expand All @@ -398,7 +398,7 @@ def _calculate_value(self) -> _T:
class ManyToOneValue(DerivedValueBase[_T], Generic[_S, _T]):
def __init__(self, transformer: Callable[[Iterable[_S]], _T], *values: _S | Value[_S]):
self._input_values = tuple(values)
self._value_getters = [_create_value_getter(v) for v in self._input_values]
self._value_getters = [create_value_getter(v) for v in self._input_values]
self._transformer = transformer
super().__init__(*[v for v in self._input_values if isinstance(v, Value)])

Expand All @@ -422,8 +422,8 @@ def __init__(self, transformer: Callable[[_S, _T], _U],
self._transformer = transformer
self._of_first = first
self._of_second = second
self._first_getter = _create_value_getter(first)
self._second_getter = _create_value_getter(second)
self._first_getter = create_value_getter(first)
self._second_getter = create_value_getter(second)
super().__init__(*[v for v in (first, second) if isinstance(v, Value)])

@override
Expand All @@ -438,9 +438,9 @@ def __init__(self, transformer: Callable[[_S, _T, _U], _V],
self._of_first = first
self._of_second = second
self._of_third = third
self._first_getter = _create_value_getter(first)
self._second_getter = _create_value_getter(second)
self._third_getter = _create_value_getter(third)
self._first_getter = create_value_getter(first)
self._second_getter = create_value_getter(second)
self._third_getter = create_value_getter(third)
super().__init__(*[v for v in (first, second, third) if isinstance(v, Value)])

@override
Expand Down
52 changes: 51 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast
from contextlib import contextmanager
from typing import Any, Sequence, Callable
from pathlib import Path
from typing import Any, Sequence, Callable, Generator, Tuple
from typing import Iterable, overload, Collection
from unittest.mock import Mock

Expand All @@ -13,6 +15,54 @@
_S = TypeVar("_S")


PROJECT_ROOT_PATH = Path(__file__).parent.parent.resolve()
SOURCE_PATH = PROJECT_ROOT_PATH / "src"


def iter_python_files(source_path: Path) -> Generator[Path, None, None]:
yield from source_path.rglob("*.py")


def is_class_definition(module_path: Path, object_name: str) -> bool:
text = module_path.read_text(encoding="utf-8")
node = ast.parse(text, filename=str(module_path))
for item in node.body:
if hasattr(item, "name") and getattr(item, "name") == object_name:
if isinstance(item, ast.ClassDef):
return True
else:
return False
return False


def resolve_module_path(base_path: Path, module: str) -> Path:
unfinished_module_path = base_path / Path(*module.split("."))
init_path = unfinished_module_path / "__init__.py"
if init_path.exists():
return init_path
file_path = unfinished_module_path.with_suffix(".py")
return file_path


def is_class_import(alias: ast.alias, import_: ast.ImportFrom, source_root: Path = SOURCE_PATH) -> bool:
module = import_.module
if module is None:
return False
module_path = resolve_module_path(source_root, module)
if module_path is None:
return False
return is_class_definition(module_path, alias.name)


def iter_imported_aliases(file_path: Path) -> Generator[Tuple[ast.alias, ast.ImportFrom], None, None]:
text = file_path.read_text(encoding="utf-8")
node = ast.parse(text, filename=str(file_path))
for statement in ast.walk(node):
if isinstance(statement, ast.ImportFrom):
for alias_ in statement.names:
yield alias_, statement


class Call:
def __init__(self, *args, **kwargs):
self.args = args
Expand Down
13 changes: 13 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from conftest import iter_imported_aliases, SOURCE_PATH, iter_python_files, is_class_import


def test_no_protected_imports_except_for_classes():
lines = []
for file in iter_python_files(SOURCE_PATH):
for alias_, statement in iter_imported_aliases(file):
if alias_.name.startswith("_") and not is_class_import(alias_, statement):
lines.append(f"{file}:{statement.lineno}: imports protected name '{alias_.name}'")
if len(lines) > 0:
pytest.fail(f"Found {len(lines)} protected imports\n" + "\n".join(lines))