Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/spellbind/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from spellbind.emitters import Emitter, TriEmitter, BiEmitter, ValueEmitter
from spellbind.functions import assert_parameter_max_count
from spellbind.observables import Observable, ValueObservable, BiObservable, TriObservable, Observer, \
ValueObserver, BiObserver, TriObserver, Subscription, DeadReferenceError, WeakSubscription, StrongSubscription
ValueObserver, BiObserver, TriObserver, Subscription, WeakSubscription, StrongSubscription, \
RemoveSubscriptionError

_S = TypeVar("_S")
_T = TypeVar("_T")
_U = TypeVar("_U")
_O = TypeVar('_O', bound=Callable)


class BaseEvent(Generic[_O], ABC):
class _BaseEvent(Generic[_O], ABC):
_subscriptions: list[Subscription[_O]]

def __init__(self):
Expand All @@ -22,13 +23,13 @@ def __init__(self):
def _get_parameter_count(self) -> int:
raise NotImplementedError

def observe(self, observer: _O) -> None:
def observe(self, observer: _O, times: int | None = None) -> None:
assert_parameter_max_count(observer, self._get_parameter_count())
self._subscriptions.append(StrongSubscription(observer))
self._subscriptions.append(StrongSubscription(observer, times))

def weak_observe(self, observer: _O) -> None:
def weak_observe(self, observer: _O, times: int | None = None) -> None:
assert_parameter_max_count(observer, self._get_parameter_count())
self._subscriptions.append(WeakSubscription(observer))
self._subscriptions.append(WeakSubscription(observer, times))

def unobserve(self, observer: _O) -> None:
for i, sub in enumerate(self._subscriptions):
Expand All @@ -46,35 +47,35 @@ def _emit(self, *args) -> None:
try:
self._subscriptions[i](*args)
i += 1
except DeadReferenceError:
except RemoveSubscriptionError:
del self._subscriptions[i]


class Event(BaseEvent[Observer], Observable, Emitter):
class Event(_BaseEvent[Observer], Observable, Emitter):
def _get_parameter_count(self) -> int:
return 0

def __call__(self) -> None:
self._emit()


class ValueEvent(Generic[_S], BaseEvent[Observer | ValueObserver[_S]], ValueObservable[_S], ValueEmitter[_S]):
class ValueEvent(Generic[_S], _BaseEvent[Observer | ValueObserver[_S]], ValueObservable[_S], ValueEmitter[_S]):
def _get_parameter_count(self) -> int:
return 1

def __call__(self, value: _S) -> None:
self._emit(value)


class BiEvent(Generic[_S, _T], BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T]], BiObservable[_S, _T], BiEmitter[_S, _T]):
class BiEvent(Generic[_S, _T], _BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T]], BiObservable[_S, _T], BiEmitter[_S, _T]):
def _get_parameter_count(self) -> int:
return 2

def __call__(self, value_0: _S, value_1: _T) -> None:
self._emit(value_0, value_1)


class TriEvent(Generic[_S, _T, _U], BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]], TriObservable[_S, _T, _U], TriEmitter[_S, _T, _U]):
class TriEvent(Generic[_S, _T, _U], _BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]], TriObservable[_S, _T, _U], TriEmitter[_S, _T, _U]):
def _get_parameter_count(self) -> int:
return 3

Expand Down
29 changes: 21 additions & 8 deletions src/spellbind/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,30 @@ class TriObserver(Protocol[_SC, _TC, _UC]):
def __call__(self, arg1: _SC, arg2: _TC, arg3: _UC, /) -> None: ...


class DeadReferenceError(Exception):
class RemoveSubscriptionError(Exception):
pass


class CallCountExceededError(RemoveSubscriptionError):
pass


class DeadReferenceError(RemoveSubscriptionError):
pass


class Subscription(Generic[_O], ABC):
def __init__(self, observer: _O):
def __init__(self, observer: _O, times: int | None):
self._positional_parameter_count = count_positional_parameters(observer)
self.called_counter = 0
self.max_call_count = times

def _call(self, observer: _O, *args) -> None:
self.called_counter += 1
trimmed_args = args[:self._positional_parameter_count]
observer(*trimmed_args)
if self.max_call_count is not None and self.called_counter >= self.max_call_count:
raise CallCountExceededError

@abstractmethod
def __call__(self, *args) -> None:
Expand All @@ -53,8 +66,8 @@ def matches_observer(self, observer: _O) -> bool:


class StrongSubscription(Subscription[_O], Generic[_O]):
def __init__(self, observer: _O):
super().__init__(observer)
def __init__(self, observer: _O, times: int | None):
super().__init__(observer, times)
self._observer = observer

def __call__(self, *args) -> None:
Expand All @@ -67,8 +80,8 @@ def matches_observer(self, observer: _O) -> bool:
class WeakSubscription(Subscription[_O], Generic[_O]):
_ref: ref[_O] | WeakMethod

def __init__(self, observer: _O):
super().__init__(observer)
def __init__(self, observer: _O, times: int | None):
super().__init__(observer, times)
if hasattr(observer, '__self__'):
self._ref = WeakMethod(observer)
else:
Expand All @@ -86,11 +99,11 @@ def matches_observer(self, observer: _O) -> bool:

class Observable(ABC):
@abstractmethod
def observe(self, observer: Observer) -> None:
def observe(self, observer: Observer, times: int | None = None) -> None:
raise NotImplementedError

@abstractmethod
def weak_observe(self, observer: Observer) -> None:
def weak_observe(self, observer: Observer, times: int | None = None) -> None:
raise NotImplementedError

@abstractmethod
Expand Down
Loading