diff --git a/src/spellbind/float_collections.py b/src/spellbind/float_collections.py new file mode 100644 index 0000000..ed6ef04 --- /dev/null +++ b/src/spellbind/float_collections.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import operator +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Iterable, Callable, Any, TypeVar + +from typing_extensions import TypeIs, override + +from spellbind.float_values import FloatValue, FloatConstant +from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection, \ + MappedObservableBag +from spellbind.observable_sequences import ObservableList, TypedValueList, ValueSequence, UnboxedValueSequence, \ + ObservableSequence +from spellbind.values import Value + + +_S = TypeVar("_S") + + +class ObservableFloatCollection(ObservableCollection[float], ABC): + @property + def summed(self) -> FloatValue: + return self.reduce_to_float(add_reducer=operator.add, remove_reducer=operator.sub, initial=0.0) + + @property + def multiplied(self) -> FloatValue: + return self.reduce_to_float(add_reducer=operator.mul, remove_reducer=operator.truediv, initial=1.0) + + +class MappedToFloatBag(MappedObservableBag[float], ObservableFloatCollection): + pass + + +class ObservableFloatSequence(ObservableSequence[float], ObservableFloatCollection, ABC): + pass + + +class ObservableFloatList(ObservableList[float], ObservableFloatSequence): + pass + + +class FloatValueCollection(ValueCollection[float], ABC): + @property + def summed(self) -> FloatValue: + return self.unboxed.reduce_to_float(add_reducer=operator.add, remove_reducer=operator.sub, initial=0.0) + + @property + @abstractmethod + def unboxed(self) -> ObservableFloatCollection: ... + + +class CombinedFloatValue(CombinedValue[float], FloatValue): + def __init__(self, collection: ObservableCollection[_S], combiner: Callable[[Iterable[_S]], float]) -> None: + super().__init__(collection=collection, combiner=combiner) + + +class ReducedFloatValue(ReducedValue[float], FloatValue): + def __init__(self, + collection: ObservableCollection[_S], + add_reducer: Callable[[float, _S], float], + remove_reducer: Callable[[float, _S], float], + initial: float): + super().__init__(collection=collection, + add_reducer=add_reducer, + remove_reducer=remove_reducer, + initial=initial) + + +class UnboxedFloatValueSequence(UnboxedValueSequence[float], ObservableFloatSequence): + def __init__(self, sequence: FloatValueSequence) -> None: + super().__init__(sequence) + + +class FloatValueSequence(ValueSequence[float], FloatValueCollection, ABC): + @cached_property + @override + def unboxed(self) -> ObservableFloatSequence: + return UnboxedFloatValueSequence(self) + + +class FloatValueList(TypedValueList[float], FloatValueSequence): + def __init__(self, values: Iterable[float | Value[float]] | None = None): + def is_float(value: Any) -> TypeIs[float]: + return isinstance(value, float) + super().__init__(values, checker=is_float, constant_factory=FloatConstant.of) diff --git a/src/spellbind/int_collections.py b/src/spellbind/int_collections.py index 42acdaa..e76e13a 100644 --- a/src/spellbind/int_collections.py +++ b/src/spellbind/int_collections.py @@ -7,8 +7,10 @@ from typing_extensions import TypeIs, override +from spellbind.float_values import FloatValue from spellbind.int_values import IntValue, IntConstant -from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection +from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection, \ + MappedObservableBag from spellbind.observable_sequences import ObservableList, TypedValueList, ValueSequence, UnboxedValueSequence, \ ObservableSequence from spellbind.values import Value @@ -27,6 +29,10 @@ def multiplied(self) -> IntValue: return self.reduce_to_int(add_reducer=operator.mul, remove_reducer=operator.floordiv, initial=1) +class MappedToIntBag(MappedObservableBag[int], ObservableIntCollection): + pass + + class ObservableIntSequence(ObservableSequence[int], ObservableIntCollection, ABC): pass @@ -50,6 +56,11 @@ def __init__(self, collection: ObservableCollection[_S], combiner: Callable[[Ite super().__init__(collection=collection, combiner=combiner) +class CombinedFloatValue(CombinedValue[float], FloatValue): + def __init__(self, collection: ObservableCollection[_S], combiner: Callable[[Iterable[_S]], float]) -> None: + super().__init__(collection=collection, combiner=combiner) + + class ReducedIntValue(ReducedValue[int], IntValue): def __init__(self, collection: ObservableCollection[_S], diff --git a/src/spellbind/observable_collections.py b/src/spellbind/observable_collections.py index 12d56d6..20eaf20 100644 --- a/src/spellbind/observable_collections.py +++ b/src/spellbind/observable_collections.py @@ -1,8 +1,9 @@ from __future__ import annotations import functools +import logging from abc import ABC, abstractmethod -from typing import TypeVar, Generic, Collection, Callable, Iterable, Iterator, Any +from typing import TypeVar, Generic, Collection, Callable, Iterable, Iterator, Any, TYPE_CHECKING from typing_extensions import override @@ -11,16 +12,24 @@ from spellbind.bool_values import BoolValue from spellbind.deriveds import Derived from spellbind.event import BiEvent, ValueEvent +from spellbind.float_values import FloatValue from spellbind.int_values import IntValue, IntVariable from spellbind.observables import ValuesObservable, ValueObservable, Observer, ValueObserver, BiObserver, \ Subscription from spellbind.str_values import StrValue from spellbind.values import Value, EMPTY_FROZEN_SET +if TYPE_CHECKING: + from spellbind.float_collections import ObservableFloatCollection + from spellbind.int_collections import ObservableIntCollection + + _S = TypeVar("_S") _S_co = TypeVar("_S_co", covariant=True) _T = TypeVar("_T") +_logger = logging.getLogger(__name__) + class ObservableCollection(Collection[_S_co], Generic[_S_co], ABC): @property @@ -56,6 +65,11 @@ def combine_to_int(self, combiner: Callable[[Iterable[_S_co]], int]) -> IntValue return CombinedIntValue(self, combiner=combiner) + def combine_to_float(self, combiner: Callable[[Iterable[_S_co]], float]) -> 'FloatValue': + from spellbind.float_collections import CombinedFloatValue + + return CombinedFloatValue(self, combiner=combiner) + def reduce(self, add_reducer: Callable[[_T, _S_co], _T], remove_reducer: Callable[[_T, _S_co], _T], @@ -79,7 +93,7 @@ def reduce_to_str(self, def reduce_to_int(self, add_reducer: Callable[[int, _S_co], int], remove_reducer: Callable[[int, _S_co], int], - initial: int) -> IntValue: + initial: int = 0) -> IntValue: from spellbind.int_collections import ReducedIntValue return ReducedIntValue(self, @@ -87,9 +101,33 @@ def reduce_to_int(self, remove_reducer=remove_reducer, initial=initial) - def filter_to_bag(self, predicate: Callable[[_S_co], bool]) -> FilteredObservableBag[_S_co]: + def reduce_to_float(self, + add_reducer: Callable[[float, _S_co], float], + remove_reducer: Callable[[float, _S_co], float], + initial: float = 0.) -> FloatValue: + from spellbind.float_collections import ReducedFloatValue + + return ReducedFloatValue(self, + add_reducer=add_reducer, + remove_reducer=remove_reducer, + initial=initial) + + def filter_to_bag(self, predicate: Callable[[_S_co], bool]) -> ObservableCollection[_S_co]: return FilteredObservableBag(self, predicate) + def map(self, transform: Callable[[_S_co], _T]) -> ObservableCollection[_T]: + return MappedObservableBag(self, transform) + + def map_to_float(self, transform: Callable[[_S_co], float]) -> ObservableFloatCollection: + from spellbind.float_collections import MappedToFloatBag + + return MappedToFloatBag(self, transform) + + def map_to_int(self, transform: Callable[[_S_co], int]) -> ObservableIntCollection: + from spellbind.int_collections import MappedToIntBag + + return MappedToIntBag(self, transform) + class ReducedValue(Value[_S], Generic[_S]): def __init__(self, @@ -156,18 +194,6 @@ def is_observed(self, by: Callable[..., Any] | None = None) -> bool: class ValueCollection(ObservableCollection[Value[_S]], Generic[_S], ABC): - @override - def reduce_to_int(self, - add_reducer: Callable[[int, Value[_S]], int], - remove_reducer: Callable[[int, Value[_S]], int], - initial: int) -> IntValue: - from spellbind.int_collections import ReducedIntValue - - return ReducedIntValue(self, - add_reducer=add_reducer, - remove_reducer=remove_reducer, - initial=initial) - def value_iterable(self) -> Iterable[_S]: return (value.value for value in self) @@ -281,6 +307,52 @@ def _clear(self) -> None: self._action_event(clear_action()) +class MappedObservableBag(_ObservableBagBase[_S], Generic[_S]): + def __init__(self, source: ObservableCollection[_T], transform: Callable[[_T], _S]) -> None: + super().__init__(tuple(transform(item) for item in source)) + self._source = source + self._transform = transform + + self._source.on_change.observe(self._on_source_action) + + def _on_source_action(self, action: CollectionAction[Any]) -> None: + if isinstance(action, ClearAction): + self._clear() + elif isinstance(action, ReverseAction): + pass + elif isinstance(action, DeltasAction): + mapped_action = action.map(self._transform) + total_count = self._len_value.value + for delta in mapped_action.delta_actions: + if delta.is_add: + self._item_counts[delta.value] = self._item_counts.get(delta.value, 0) + 1 + total_count += 1 + else: + count = self._item_counts.get(delta.value, 0) + if count > 0: + if count == 1: + del self._item_counts[delta.value] + else: + self._item_counts[delta.value] = count - 1 + total_count -= 1 + else: + _logger.warning( + f"Attempted to remove {delta.value!r} from {self.__class__.__name__}, " + f"but item not present. Source collection may be inconsistent with the mapped collection." + ) + + if self._is_observed(): + with self._len_value.set_delay_notify(total_count): + self._action_event(mapped_action) + self._deltas_event(mapped_action) + else: + self._len_value.value = total_count + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}({list(self)!r})" + + class FilteredObservableBag(_ObservableBagBase[_S], Generic[_S]): def __init__(self, source: ObservableCollection[_S], predicate: Callable[[_S], bool]) -> None: super().__init__(tuple(item for item in source if predicate(item))) @@ -310,6 +382,11 @@ def _on_source_action(self, action: CollectionAction[_S]) -> None: else: self._item_counts[delta.value] = count - 1 total_count -= 1 + else: + _logger.warning( + f"Attempted to remove {delta.value!r} from {self.__class__.__name__}, " + f"but item not present. Source collection may be inconsistent with the filtered collection." + ) if self._is_observed(): with self._len_value.set_delay_notify(total_count): diff --git a/src/spellbind/observable_sequences.py b/src/spellbind/observable_sequences.py index f594911..a6e82a6 100644 --- a/src/spellbind/observable_sequences.py +++ b/src/spellbind/observable_sequences.py @@ -35,6 +35,7 @@ class ObservableSequence(Sequence[_S_co], ObservableCollection[_S_co], Generic[_ def on_change(self) -> ValueObservable[AtIndicesDeltasAction[_S_co] | ClearAction[_S_co] | ReverseAction[_S_co] | ElementsChangedAction[_S_co]]: ... @abstractmethod + @override def map(self, transformer: Callable[[_S_co], _T]) -> ObservableSequence[_T]: ... @override @@ -848,25 +849,25 @@ def __eq__(self, other: object) -> bool: class MappedIndexObservableSequence(IndexObservableSequenceBase[_S], Generic[_S]): - def __init__(self, mapped_from: IndexObservableSequence[_T], map_func: Callable[[_T], _S]) -> None: - super().__init__(map_func(item) for item in mapped_from) - self._mapped_from = mapped_from - self._map_func = map_func + def __init__(self, source: IndexObservableSequence[_T], transform: Callable[[_T], _S]) -> None: + super().__init__(transform(item) for item in source) + self._source = source + self._transform = transform def on_action(other_action: AtIndicesDeltasAction[_T] | ClearAction[_T] | ReverseAction[_T]) -> None: if isinstance(other_action, AtIndicesDeltasAction): if isinstance(other_action, ExtendAction): - self._extend((self._map_func(item) for item in other_action.items)) + self._extend((self._transform(item) for item in other_action.items)) else: for delta in other_action.delta_actions: if delta.is_add: - value: _S = self._map_func(delta.value) + value: _S = self._transform(delta.value) self._values.insert(delta.index, value) else: del self._values[delta.index] if self._is_observed(): with self._len_value.set_delay_notify(len(self._values)): - action = other_action.map(self._map_func) + action = other_action.map(self._transform) self._action_event(action) self._deltas_event(action) else: @@ -876,7 +877,7 @@ def on_action(other_action: AtIndicesDeltasAction[_T] | ClearAction[_T] | Revers elif isinstance(other_action, ReverseAction): self._reverse() - mapped_from.on_change.observe(on_action) + source.on_change.observe(on_action) def _is_observed(self) -> bool: return self._action_event.is_observed() or self._deltas_event.is_observed() diff --git a/tests/test_collections/test_mapped_observable_bag.py b/tests/test_collections/test_mapped_observable_bag.py new file mode 100644 index 0000000..906f126 --- /dev/null +++ b/tests/test_collections/test_mapped_observable_bag.py @@ -0,0 +1,461 @@ +from conftest import ValueCollectionObservers, OneParameterObserver +from spellbind.actions import clear_action, SimpleRemoveOneAction, SimpleAddOneAction, SimpleOneElementChangedAction +from spellbind.observable_sequences import ObservableList +from spellbind.observable_collections import MappedObservableBag + + +def test_initialize_empty(): + source = ObservableList() + mapped = MappedObservableBag(source, lambda x: x * 2) + assert len(mapped) == 0 + assert list(mapped) == [] + + +def test_initialize_with_items(): + source = ObservableList([1, 2, 3, 4, 5]) + mapped = MappedObservableBag(source, lambda x: x * 2) + assert len(mapped) == 5 + assert sorted(mapped) == [2, 4, 6, 8, 10] + + +def test_initialize_with_duplicates(): + source = ObservableList([1, 2, 1, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + assert len(mapped) == 4 + assert sorted(mapped) == [2, 2, 4, 6] + + +def test_initialize_with_transform_to_string(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: f"item_{x}") + assert len(mapped) == 3 + assert sorted(mapped) == ["item_1", "item_2", "item_3"] + + +def test_contains_mapped_item(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + assert 2 in mapped + assert 4 in mapped + assert 6 in mapped + assert 8 in mapped + + +def test_does_not_contain_unmapped_item(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + assert 1 not in mapped + assert 3 not in mapped + assert 5 not in mapped + + +def test_append_item(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.append(4) + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 4, 6, 8] + observers.assert_added_calls(8) + observers.assert_single_action(SimpleAddOneAction(8)) + + +def test_append_duplicate_item(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.append(2) + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 4, 4, 6] + observers.assert_added_calls(4) + observers.assert_single_action(SimpleAddOneAction(4)) + + +def test_append_multiple_that_map_to_same(): + source = ObservableList([1, 2]) + mapped = MappedObservableBag(source, lambda x: x // 2) + observers = ValueCollectionObservers(mapped) + + source.append(3) + + assert len(mapped) == 3 + assert sorted(mapped) == [0, 1, 1] + observers.assert_added_calls(1) + observers.assert_single_action(SimpleAddOneAction(1)) + + +def test_remove_item(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.remove(2) + + assert len(mapped) == 3 + assert sorted(mapped) == [2, 6, 8] + observers.assert_removed_calls(4) + observers.assert_single_action(SimpleRemoveOneAction(4)) + + +def test_remove_duplicate_item(): + source = ObservableList([1, 2, 1, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 2, 4, 6] + + source.remove(1) + assert len(mapped) == 3 + assert 2 in mapped + + source.remove(1) + assert len(mapped) == 2 + assert 2 not in mapped + + observers.assert_removed_calls(2, 2) + observers.assert_actions(SimpleRemoveOneAction(2), SimpleRemoveOneAction(2)) + + +def test_clear_source(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.clear() + + assert len(mapped) == 0 + assert list(mapped) == [] + observers.assert_removed_calls(2, 4, 6, 8) + observers.assert_single_action(clear_action()) + + +def test_clear_empty_source(): + source = ObservableList() + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.clear() + + assert len(mapped) == 0 + observers.assert_not_called() + + +def test_length_value_updates(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + + length_observer = OneParameterObserver() + mapped.length_value.observe(length_observer) + + assert mapped.length_value.value == 3 + + source.append(4) + assert mapped.length_value.value == 4 + + source.remove(2) + assert mapped.length_value.value == 3 + + assert length_observer.calls == [4, 3] + + +def test_multiple_operations(): + source = ObservableList([1, 2, 3, 4, 5]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + assert sorted(mapped) == [2, 4, 6, 8, 10] + + source.append(6) + assert sorted(mapped) == [2, 4, 6, 8, 10, 12] + + source.remove(2) + assert sorted(mapped) == [2, 6, 8, 10, 12] + + source.clear() + assert list(mapped) == [] + + observers.assert_calls((12, True), (4, False), (2, False), (6, False), (8, False), (10, False), (12, False)) + observers.assert_actions( + SimpleAddOneAction(12), + SimpleRemoveOneAction(4), + clear_action() + ) + + +def test_is_unobserved_initially(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + + assert not mapped.on_change.is_observed() + assert not mapped.delta_observable.is_observed() + + +def test_multiple_adds_and_removes_with_duplicates(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + assert sorted(mapped) == [2, 4, 6] + assert len(mapped) == 3 + + source.append(1) + assert len(mapped) == 4 + assert sorted(mapped) == [2, 2, 4, 6] + + source.append(4) + assert len(mapped) == 5 + assert sorted(mapped) == [2, 2, 4, 6, 8] + + source.remove(1) + assert len(mapped) == 4 + assert 2 in mapped + + source.remove(1) + assert len(mapped) == 3 + assert 2 not in mapped + assert sorted(mapped) == [4, 6, 8] + + observers.assert_calls((2, True), (8, True), (2, False), (2, False)) + observers.assert_actions( + SimpleAddOneAction(2), + SimpleAddOneAction(8), + SimpleRemoveOneAction(2), + SimpleRemoveOneAction(2) + ) + + +def test_setitem(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source[1] = 5 + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 6, 8, 10] + observers.assert_calls((4, False), (10, True)) + observers.assert_single_action(SimpleOneElementChangedAction(old_item=4, new_item=10)) + + +def test_setitem_to_duplicate(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source[1] = 1 + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 2, 6, 8] + observers.assert_calls((4, False), (2, True)) + observers.assert_single_action(SimpleOneElementChangedAction(old_item=4, new_item=2)) + + +def test_extend_with_items(): + source = ObservableList([1, 2]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.extend([3, 4, 5]) + + assert len(mapped) == 5 + assert sorted(mapped) == [2, 4, 6, 8, 10] + observers.assert_added_calls(6, 8, 10) + assert len(observers.on_change_observer.calls) == 1 + + +def test_extend_with_duplicates(): + source = ObservableList([1, 2]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.extend([1, 3, 1, 4]) + + assert len(mapped) == 6 + assert sorted(mapped) == [2, 2, 2, 4, 6, 8] + observers.assert_added_calls(2, 6, 2, 8) + assert len(observers.on_change_observer.calls) == 1 + + +def test_insert_item(): + source = ObservableList([1, 3, 5]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.insert(1, 2) + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 4, 6, 10] + observers.assert_added_calls(4) + observers.assert_single_action(SimpleAddOneAction(4)) + + +def test_insert_duplicate_item(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.insert(1, 1) + + assert len(mapped) == 4 + assert sorted(mapped) == [2, 2, 4, 6] + observers.assert_added_calls(2) + observers.assert_single_action(SimpleAddOneAction(2)) + + +def test_del_by_index(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + del source[1] + + assert len(mapped) == 3 + assert sorted(mapped) == [2, 6, 8] + observers.assert_removed_calls(4) + observers.assert_single_action(SimpleRemoveOneAction(4)) + + +def test_del_by_index_with_duplicates(): + source = ObservableList([1, 2, 1, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + del source[0] + + assert len(mapped) == 3 + assert sorted(mapped) == [2, 4, 6] + observers.assert_removed_calls(2) + observers.assert_single_action(SimpleRemoveOneAction(2)) + + +def test_del_by_slice(): + source = ObservableList([1, 2, 3, 4, 5, 6]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + del source[1:4] + + assert len(mapped) == 3 + assert sorted(mapped) == [2, 10, 12] + observers.assert_removed_calls(4, 6, 8) + assert len(observers.on_change_observer.calls) == 1 + + +def test_del_by_slice_with_duplicates(): + source = ObservableList([1, 2, 1, 3, 1, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + del source[1:4] + + assert len(mapped) == 3 + assert sorted(mapped) == [2, 2, 8] + observers.assert_removed_calls(4, 2, 6) + assert len(observers.on_change_observer.calls) == 1 + + +def test_pop(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + popped = source.pop(2) + + assert popped == 3 + assert len(mapped) == 3 + assert sorted(mapped) == [2, 4, 8] + observers.assert_removed_calls(6) + observers.assert_single_action(SimpleRemoveOneAction(6)) + + +def test_pop_default(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + popped = source.pop() + + assert popped == 4 + assert len(mapped) == 3 + assert sorted(mapped) == [2, 4, 6] + observers.assert_removed_calls(8) + observers.assert_single_action(SimpleRemoveOneAction(8)) + + +def test_pop_with_duplicates(): + source = ObservableList([1, 2, 1, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + popped = source.pop(0) + + assert popped == 1 + assert len(mapped) == 3 + assert 2 in mapped + assert sorted(mapped) == [2, 4, 6] + observers.assert_removed_calls(2) + observers.assert_single_action(SimpleRemoveOneAction(2)) + + +def test_slice_assignment(): + source = ObservableList([1, 2, 3, 4, 5]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source[1:3] = [10, 20] + + assert len(mapped) == 5 + assert sorted(mapped) == [2, 8, 10, 20, 40] + assert len(observers.delta_observer.calls) == 4 + assert len(observers.on_change_observer.calls) == 1 + + +def test_slice_assignment_with_duplicates(): + source = ObservableList([1, 2, 3, 1, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source[1:3] = [1, 1] + + assert len(mapped) == 5 + assert sorted(mapped) == [2, 2, 2, 2, 8] + assert len(observers.delta_observer.calls) == 4 + assert len(observers.on_change_observer.calls) == 1 + + +def test_reverse(): + source = ObservableList([1, 2, 3, 4]) + mapped = MappedObservableBag(source, lambda x: x * 2) + observers = ValueCollectionObservers(mapped) + + source.reverse() + + assert sorted(mapped) == [2, 4, 6, 8] + observers.assert_not_called() + + +def test_transform_produces_same_values(): + source = ObservableList([1, 2, 3, 4, 5]) + mapped = MappedObservableBag(source, lambda x: x // 2) + + assert len(mapped) == 5 + assert sorted(mapped) == [0, 1, 1, 2, 2] + + observers = ValueCollectionObservers(mapped) + + source.remove(3) + assert sorted(mapped) == [0, 1, 2, 2] + observers.assert_removed_calls(1) + + +def test_repr(): + source = ObservableList([1, 2, 3]) + mapped = MappedObservableBag(source, lambda x: x * 2) + + repr_str = repr(mapped) + assert "MappedObservableBag" in repr_str diff --git a/tests/test_collections/test_observable_float_lists/test_observable_float_list.py b/tests/test_collections/test_observable_float_lists/test_observable_float_list.py new file mode 100644 index 0000000..8c8b7d8 --- /dev/null +++ b/tests/test_collections/test_observable_float_lists/test_observable_float_list.py @@ -0,0 +1,121 @@ +from conftest import OneParameterObserver +from spellbind.float_collections import ObservableFloatList +from spellbind.observable_sequences import ObservableList + + +def test_combine_floats(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + combined = float_list.combine_to_float(combiner=sum) + assert combined.value == 6.0 + + +def test_derive_commutative_reverse_not_called(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + calls = [] + + def reducer(x, y): + calls.append("added") + return 0.0 + + summed = float_list.reduce(add_reducer=reducer, remove_reducer=reducer, initial=1.0) + calls.clear() + float_list.reverse() + assert calls == [] + + +def test_derive_commutative_reduce_order_not_called(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + calls = [] + + def add_reducer(x, y): + calls.append(f"added {y}") + return 0.0 + + def remove_reducer(x, y): + calls.append(f"removed {y}") + return 0.0 + summed = float_list.reduce(add_reducer=add_reducer, remove_reducer=remove_reducer, initial=1.0) + calls.clear() + float_list.append(1.0) + float_list.append(2.0) + float_list.append(3.0) + float_list.pop(1) + assert calls == ["added 1.0", "added 2.0", "added 3.0", "removed 2.0"] + + +def test_reduce_to_float_half_string_lengths(): + string_list = ObservableList(["a", "bb", "ccc"]) + half_length = string_list.reduce_to_float( + add_reducer=lambda acc, s: acc + len(s) / 2.0, + remove_reducer=lambda acc, s: acc - len(s) / 2.0, + initial=0.0 + ) + assert half_length.value == 3.0 + + +def test_reduce_to_float_half_string_lengths_append(): + string_list = ObservableList(["a", "bb", "ccc"]) + half_length = string_list.reduce_to_float( + add_reducer=lambda acc, s: acc + len(s) / 2.0, + remove_reducer=lambda acc, s: acc - len(s) / 2.0, + initial=0.0 + ) + observer = OneParameterObserver() + half_length.observe(observer) + + assert half_length.value == 3.0 + + string_list.append("dddd") + assert half_length.value == 5.0 + observer.assert_called_once_with(5.0) + + +def test_reduce_to_float_half_string_lengths_remove(): + string_list = ObservableList(["a", "bb", "ccc"]) + half_length = string_list.reduce_to_float( + add_reducer=lambda acc, s: acc + len(s) / 2.0, + remove_reducer=lambda acc, s: acc - len(s) / 2.0, + initial=0.0 + ) + observer = OneParameterObserver() + half_length.observe(observer) + + assert half_length.value == 3.0 + + string_list.remove("bb") + assert half_length.value == 2.0 + observer.assert_called_once_with(2.0) + + +def test_reduce_to_float_half_string_lengths_setitem(): + string_list = ObservableList(["a", "bb", "ccc"]) + half_length = string_list.reduce_to_float( + add_reducer=lambda acc, s: acc + len(s) / 2.0, + remove_reducer=lambda acc, s: acc - len(s) / 2.0, + initial=0.0 + ) + observer = OneParameterObserver() + half_length.observe(observer) + + assert half_length.value == 3.0 + + string_list[1] = "dddd" + assert half_length.value == 4.0 + observer.assert_called_once_with(4.0) + + +def test_reduce_to_float_half_string_lengths_reverse(): + string_list = ObservableList(["a", "bb", "ccc"]) + half_length = string_list.reduce_to_float( + add_reducer=lambda acc, s: acc + len(s) / 2.0, + remove_reducer=lambda acc, s: acc - len(s) / 2.0, + initial=0.0 + ) + observer = OneParameterObserver() + half_length.observe(observer) + + assert half_length.value == 3.0 + + string_list.reverse() + assert half_length.value == 3.0 + observer.assert_not_called() diff --git a/tests/test_collections/test_observable_float_lists/test_observable_float_list_sum_and_modify_list.py b/tests/test_collections/test_observable_float_lists/test_observable_float_list_sum_and_modify_list.py new file mode 100644 index 0000000..788ed37 --- /dev/null +++ b/tests/test_collections/test_observable_float_lists/test_observable_float_list_sum_and_modify_list.py @@ -0,0 +1,133 @@ +import pytest + +from conftest import OneParameterObserver +from spellbind.float_collections import ObservableFloatList, FloatValueList + + +@pytest.mark.parametrize("constructor", [ObservableFloatList, FloatValueList]) +def test_sum_float_list_append_sequentially(constructor): + float_list = constructor([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.append(4.0) + assert summed.value == 10.0 + float_list.append(5.0) + assert summed.value == 15.0 + float_list.append(6.0) + assert summed.value == 21.0 + assert observer.calls == [10.0, 15.0, 21.0] + + +def test_sum_float_list_clear(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.clear() + assert summed.value == 0.0 + observer.assert_called_once_with(0.0) + + +def test_sum_float_list_del_sequentially(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + del float_list[0] + assert summed.value == 5.0 + del float_list[0] + assert summed.value == 3.0 + del float_list[0] + assert summed.value == 0.0 + assert observer.calls == [5.0, 3.0, 0.0] + + +def test_sum_float_list_del_slice(): + float_list = ObservableFloatList([1.0, 2.0, 3.0, 4.0, 5.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 15.0 + del float_list[1:4] + assert summed.value == 6.0 + assert observer.calls == [6.0] + + +def test_sum_float_list_extend(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.extend([4.0, 5.0]) + assert summed.value == 15.0 + float_list.extend([6.0]) + assert summed.value == 21.0 + assert observer.calls == [15.0, 21.0] + + +def test_sum_float_list_insert(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.insert(0, 4.0) + assert summed.value == 10.0 + float_list.insert(2, 5.0) + assert summed.value == 15.0 + float_list.insert(5, 6.0) + assert summed.value == 21.0 + assert observer.calls == [10.0, 15.0, 21.0] + + +def test_sum_float_list_insert_all(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.insert_all(((1, 4.0), (2, 5.0), (3, 6.0))) + assert summed.value == 21.0 + assert observer.calls == [21.0] + + +def test_sum_float_list_setitem(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list[0] = 4.0 + assert summed.value == 9.0 + float_list[1] = 5.0 + assert summed.value == 12.0 + float_list[2] = 6.0 + assert summed.value == 15.0 + assert observer.calls == [9.0, 12.0, 15.0] + + +def test_sum_float_list_set_slice(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list[0:3] = [4.0, 5.0, 6.0] + assert summed.value == 15.0 + assert observer.calls == [15.0] + + +def test_sum_float_list_reverse(): + float_list = ObservableFloatList([1.0, 2.0, 3.0]) + summed = float_list.summed + observer = OneParameterObserver() + summed.observe(observer) + assert summed.value == 6.0 + float_list.reverse() + assert summed.value == 6.0 + observer.assert_not_called() diff --git a/tests/test_collections/test_observable_int_lists/test_observable_int_list.py b/tests/test_collections/test_observable_int_lists/test_observable_int_list.py index f7a1d37..acd1fdf 100644 --- a/tests/test_collections/test_observable_int_lists/test_observable_int_list.py +++ b/tests/test_collections/test_observable_int_lists/test_observable_int_list.py @@ -1,4 +1,6 @@ +from conftest import OneParameterObserver from spellbind.int_collections import ObservableIntList +from spellbind.observable_sequences import ObservableList def test_combine_ints(): @@ -39,3 +41,81 @@ def remove_reducer(x, y): int_list.append(3) int_list.pop(1) assert calls == ["added 1", "added 2", "added 3", "removed 2"] + + +def test_reduce_to_int_string_lengths(): + string_list = ObservableList(["a", "bb", "ccc"]) + total_length = string_list.reduce_to_int( + add_reducer=lambda acc, s: acc + len(s), + remove_reducer=lambda acc, s: acc - len(s), + initial=0 + ) + assert total_length.value == 6 + + +def test_reduce_to_int_string_lengths_append(): + string_list = ObservableList(["a", "bb", "ccc"]) + total_length = string_list.reduce_to_int( + add_reducer=lambda acc, s: acc + len(s), + remove_reducer=lambda acc, s: acc - len(s), + initial=0 + ) + observer = OneParameterObserver() + total_length.observe(observer) + + assert total_length.value == 6 + + string_list.append("dddd") + assert total_length.value == 10 + observer.assert_called_once_with(10) + + +def test_reduce_to_int_string_lengths_remove(): + string_list = ObservableList(["a", "bb", "ccc"]) + total_length = string_list.reduce_to_int( + add_reducer=lambda acc, s: acc + len(s), + remove_reducer=lambda acc, s: acc - len(s), + initial=0 + ) + observer = OneParameterObserver() + total_length.observe(observer) + + assert total_length.value == 6 + + string_list.remove("bb") + assert total_length.value == 4 + observer.assert_called_once_with(4) + + +def test_reduce_to_int_string_lengths_setitem(): + string_list = ObservableList(["a", "bb", "ccc"]) + total_length = string_list.reduce_to_int( + add_reducer=lambda acc, s: acc + len(s), + remove_reducer=lambda acc, s: acc - len(s), + initial=0 + ) + observer = OneParameterObserver() + total_length.observe(observer) + + assert total_length.value == 6 + + string_list[1] = "dddd" + assert total_length.value == 8 + observer.assert_called_once_with(8) + + +def test_reduce_to_int_string_lengths_reverse(): + string_list = ObservableList(["a", "bb", "ccc"]) + total_length = string_list.reduce_to_int( + add_reducer=lambda acc, s: acc + len(s), + remove_reducer=lambda acc, s: acc - len(s), + initial=0 + ) + observer = OneParameterObserver() + total_length.observe(observer) + + assert total_length.value == 6 + + string_list.reverse() + assert total_length.value == 6 + observer.assert_not_called()