Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/spellbind/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def is_permutation_only(self) -> bool: ...
@abstractmethod
def map(self, transformer: Callable[[_S_co], _T]) -> CollectionAction[_T]: ...

@abstractmethod
def filter(self, predicate: Callable[[_S_co], bool]) -> CollectionAction[_S_co] | None: ...

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
Expand All @@ -34,6 +37,10 @@ def is_permutation_only(self) -> bool:
def map(self, transformer: Callable[[_S_co], _T]) -> ClearAction[_T]:
return clear_action()

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> ClearAction[_S_co]:
return self


class SingleValueAction(CollectionAction[_S_co], Generic[_S_co]):
@property
Expand All @@ -56,6 +63,13 @@ def map(self, transformer: Callable[[_S_co], _T]) -> DeltasAction[_T]:
mapped = tuple(action.map(transformer) for action in self.delta_actions)
return SimpleDeltasAction(mapped)

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> DeltasAction[_S_co] | None:
filtered_actions = tuple(action for action in self.delta_actions if predicate(action.value))
if not filtered_actions:
return None
return SimpleDeltasAction(filtered_actions)


class SimpleDeltasAction(DeltasAction[_S_co], Generic[_S_co]):
def __init__(self, delta_actions: tuple[DeltaAction[_S_co], ...]):
Expand Down Expand Up @@ -86,6 +100,10 @@ def delta_actions(self) -> tuple[DeltaAction[_S_co], ...]:
@override
def map(self, transformer: Callable[[_S_co], _T]) -> DeltaAction[_T]: ...

@abstractmethod
@override
def filter(self, predicate: Callable[[_S_co], bool]) -> DeltaAction[_S_co] | None: ...


class AddOneAction(DeltaAction[_S_co], Generic[_S_co], ABC):
@property
Expand All @@ -97,6 +115,12 @@ def is_add(self) -> bool:
def map(self, transformer: Callable[[_S_co], _T]) -> AddOneAction[_T]:
return SimpleAddOneAction(transformer(self.value))

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> AddOneAction[_S_co] | None:
if predicate(self.value):
return SimpleAddOneAction(self.value)
return None

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(value={self.value})"
Expand All @@ -111,6 +135,12 @@ def __init__(self, item: _S_co) -> None:
def value(self) -> _S_co:
return self._item

@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, AddOneAction):
return NotImplemented
return bool(self.value == other.value)


class RemoveOneAction(DeltaAction[_S_co], Generic[_S_co], ABC):
@property
Expand All @@ -122,6 +152,12 @@ def is_add(self) -> bool:
def map(self, transformer: Callable[[_S_co], _T]) -> RemoveOneAction[_T]:
return SimpleRemoveOneAction(transformer(self.value))

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> RemoveOneAction[_S_co] | None:
if predicate(self.value):
return SimpleRemoveOneAction(self.value)
return None

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(value={self.value})"
Expand All @@ -136,6 +172,12 @@ def __init__(self, item: _S_co) -> None:
def value(self) -> _S_co:
return self._item

@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, RemoveOneAction):
return NotImplemented
return bool(self.value == other.value)


class ElementsChangedAction(DeltasAction[_S_co], Generic[_S_co], ABC):
@property
Expand Down Expand Up @@ -195,6 +237,19 @@ def delta_actions(self) -> tuple[DeltaAction[_S_co], ...]:
def map(self, transformer: Callable[[_S_co], _T]) -> OneElementChangedAction[_T]:
return SimpleOneElementChangedAction(new_item=transformer(self.new_item), old_item=transformer(self.old_item))

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> DeltasAction[_S_co] | None:
old_matches = predicate(self.old_item)
new_matches = predicate(self.new_item)
if old_matches and new_matches:
return SimpleOneElementChangedAction(new_item=self.new_item, old_item=self.old_item)
elif old_matches:
return SimpleRemoveOneAction(self.old_item)
elif new_matches:
return SimpleAddOneAction(self.new_item)
else:
return None


class SimpleOneElementChangedAction(OneElementChangedAction[_S_co], Generic[_S_co]):
def __init__(self, *, new_item: _S_co, old_item: _S_co):
Expand Down Expand Up @@ -614,6 +669,10 @@ def is_permutation_only(self) -> bool:
def map(self, transformer: Callable[[_S_co], _T]) -> ReverseAction[_T]:
return reverse_action()

@override
def filter(self, predicate: Callable[[_S_co], bool]) -> ReverseAction[_S_co]:
return self


class ExtendAction(AtIndicesDeltasAction[_S_co], SequenceAction[_S_co], Generic[_S_co], ABC):
@property
Expand Down
107 changes: 104 additions & 3 deletions src/spellbind/observable_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from typing_extensions import override

from spellbind.actions import CollectionAction, DeltaAction, DeltasAction, ClearAction
from spellbind.actions import CollectionAction, DeltaAction, DeltasAction, ClearAction, ReverseAction, clear_action
from spellbind.bool_values import BoolValue
from spellbind.deriveds import Derived
from spellbind.event import BiEvent
from spellbind.int_values import IntValue
from spellbind.event import BiEvent, ValueEvent
from spellbind.int_values import IntValue, IntVariable
from spellbind.observables import ValuesObservable, ValueObservable, Observer, ValueObserver, BiObserver, \
Subscription
from spellbind.str_values import StrValue
Expand Down Expand Up @@ -86,6 +86,9 @@ def reduce_to_int(self,
remove_reducer=remove_reducer,
initial=initial)

def filter_to_bag(self, predicate: Callable[[_S_co], bool]) -> FilteredObservableBag[_S_co]:
return FilteredObservableBag(self, predicate)


class ReducedValue(Value[_S], Generic[_S]):
def __init__(self,
Expand Down Expand Up @@ -213,3 +216,101 @@ def unobserve(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _S])
@override
def is_observed(self, by: Callable[..., Any] | None = None) -> bool:
return self._on_change.is_observed(by=by)


class FilteredObservableBag(ObservableCollection[_S], Generic[_S]):
def __init__(self, source: ObservableCollection[_S], predicate: Callable[[_S], bool]) -> None:
self._source = source
self._predicate = predicate
self._item_counts: dict[_S, int] = {}

total_count = 0
for item in source:
if predicate(item):
self._item_counts[item] = self._item_counts.get(item, 0) + 1
total_count += 1

self._on_change = ValueEvent[CollectionAction[_S]]()
self._delta_event = ValueEvent[DeltasAction[_S]]()
self._delta_observable = self._delta_event.map_to_values_observable(
transformer=lambda deltas_action: deltas_action.delta_actions
)
self._len_value = IntVariable(total_count)

source.on_change.observe(self._on_source_action)

def _on_source_action(self, action: CollectionAction[_S]) -> None:
if isinstance(action, ClearAction):
self._clear()
elif isinstance(action, ReverseAction):
pass
elif isinstance(action, DeltasAction):
filtered_action = action.filter(self._predicate)
if filtered_action is not None:
total_count = self._len_value.value
for delta in filtered_action.delta_actions:
if delta.is_add:
self._item_counts[delta.value] = self._item_counts.get(delta.value, 0) + 1
total_count += 1
else:
count = self._item_counts.get(delta.value, 0)
if count > 0:
if count == 1:
del self._item_counts[delta.value]
else:
self._item_counts[delta.value] = count - 1
total_count -= 1

if self._is_observed():
with self._len_value.set_delay_notify(total_count):
self._on_change(filtered_action)
self._delta_event(filtered_action)
else:
self._len_value.value = total_count

def _clear(self) -> None:
if self._len_value.value == 0:
return
self._item_counts.clear()
if self._is_observed():
with self._len_value.set_delay_notify(0):
action: ClearAction[_S] = clear_action()
self._on_change(action)
else:
self._len_value.value = 0

def _is_observed(self) -> bool:
return self._on_change.is_observed() or self._delta_event.is_observed()

@property
@override
def on_change(self) -> ValueObservable[CollectionAction[_S]]:
return self._on_change

@property
@override
def delta_observable(self) -> ValuesObservable[DeltaAction[_S]]:
return self._delta_observable

@property
@override
def length_value(self) -> IntValue:
return self._len_value

@override
def __len__(self) -> int:
return self._len_value.value

@override
def __contains__(self, item: object) -> bool:
return item in self._item_counts

@override
def __iter__(self) -> Iterator[_S]:
for item, count in self._item_counts.items():
for _ in range(count):
yield item

@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}({list(self)!r})"
2 changes: 1 addition & 1 deletion src/spellbind/observable_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def on_action(other_action: AtIndicesDeltasAction[_T] | ClearAction[_T] | Revers
else:
for delta in other_action.delta_actions:
if delta.is_add:
value = self._map_func(delta.value)
value: _S = self._map_func(delta.value)
self._values.insert(delta.index, value)
else:
del self._values[delta.index]
Expand Down
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,41 @@ def assert_single_action(self, action: CollectionAction):
self.on_change_observer.assert_called_once_with(action)


class ValueCollectionObservers(Observers):
def __init__(self, collection: ObservableCollection):
self.on_change_observer = OneParameterObserver()
self.delta_observer = OneParameterObserver()
collection.on_change.observe(self.on_change_observer)
collection.delta_observable.observe_single(self.delta_observer)
super().__init__(self.on_change_observer, self.delta_observer)

def assert_added_calls(self, *expected_adds: Any):
self.assert_calls(*(append_bool(add, True) for add in expected_adds))

def assert_removed_calls(self, *expected_removes: Any):
self.assert_calls(*(append_bool(remove, False) for remove in expected_removes))

def assert_calls(self, *expected_calls: tuple[Any, bool]):
delta_calls = self.delta_observer.calls
if not len(delta_calls) == len(expected_calls):
pytest.fail(f"Expected {len(expected_calls)} calls, got {len(delta_calls)}")
for i, (call, expected_call) in enumerate(zip(delta_calls, expected_calls)):
action = call.get_arg()
assert isinstance(action, DeltaAction)
expected_value, expected_added = expected_call
if not action.is_add == expected_added:
pytest.fail(f"Error call {i}. Expected {'add' if expected_added else 'remove'}, got {'add' if action.is_add else 'remove'}")
if not action.value == expected_value:
pytest.fail(f"Error call {i}. Expected value {expected_value}, got {action.value}")

def assert_actions(self, *actions: CollectionAction):
assert self.on_change_observer.calls == [*actions]

def assert_single_action(self, action: CollectionAction):
assert len(self.on_change_observer.calls) == 1
assert self.on_change_observer.calls[0] == action


@contextmanager
def assert_length_changed_during_action_events_but_notifies_after(collection: ObservableCollection, expected_length: int):
events = []
Expand Down
Loading