diff --git a/.flake8 b/.flake8 index 05f0f21..fb60d07 100644 --- a/.flake8 +++ b/.flake8 @@ -9,4 +9,4 @@ exclude = # Allow assigning lambdas in tests per-file-ignores = - tests/*:E731 \ No newline at end of file + tests/*:E731,F841 \ No newline at end of file diff --git a/src/spellbind/float_values.py b/src/spellbind/float_values.py index daef0c9..42935af 100644 --- a/src/spellbind/float_values.py +++ b/src/spellbind/float_values.py @@ -133,20 +133,22 @@ def _get_float(value: float | Value[int] | Value[float]) -> float: class CombinedFloatValues(DerivedValueBase[_U], Generic[_U], ABC): def __init__(self, *values: float | Value[int] | Value[float]): super().__init__(*[v for v in values if isinstance(v, Value)]) - self.gotten_values = [_get_float(v) for v in values] + self._gotten_values = [_get_float(v) for v in values] + self._callbacks: list[Callable] = [] for i, v in enumerate(values): if isinstance(v, Value): - v.observe(self._create_on_n_changed(i)) + v.weak_observe(self._create_on_n_changed(i)) self._value = self._calculate_value() def _create_on_n_changed(self, index: int) -> Callable[[float], None]: def on_change(new_value: float) -> None: - self.gotten_values[index] = new_value + self._gotten_values[index] = new_value self._on_result_change(self._calculate_value()) + self._callbacks.append(on_change) # keep strong reference to callback so it won't be garbage collected return on_change def _calculate_value(self) -> _U: - return self.transform(self.gotten_values) + return self.transform(self._gotten_values) def _on_result_change(self, new_value: _U) -> None: if new_value != self._value: diff --git a/src/spellbind/values.py b/src/spellbind/values.py index 76d81d8..60c295f 100644 --- a/src/spellbind/values.py +++ b/src/spellbind/values.py @@ -215,7 +215,7 @@ class DerivedValue(DerivedValueBase[_T], Generic[_S, _T], ABC): def __init__(self, of: Value[_S]): super().__init__(of) self._value = self.transform(of.value) - of.observe(self._on_source_change) + of.weak_observe(self._on_source_change) @abstractmethod def transform(self, value: _S) -> _T: @@ -257,9 +257,9 @@ def __init__(self, left: Value[_S] | _S, right: Value[_T] | _T): self._left_getter = _create_value_getter(left) self._right_getter = _create_value_getter(right) if isinstance(left, Value): - left.observe(self._on_left_change) + left.weak_observe(self._on_left_change) if isinstance(right, Value): - right.observe(self._on_right_change) + right.weak_observe(self._on_right_change) self._value = self.transform(self._left_getter(), self._right_getter()) def _on_left_change(self, new_left_value: _S) -> None: @@ -295,15 +295,17 @@ class CombinedMixedValues(DerivedValueBase[_T], Generic[_S, _T], ABC): def __init__(self, *sources: Value[_S] | _S): super().__init__(*[v for v in sources if isinstance(v, Value)]) self.gotten_values = [_get_value(v) for v in sources] + self._callbacks: list[Callable] = [] for i, v in enumerate(sources): if isinstance(v, Value): - v.observe(self._create_on_n_changed(i)) + v.weak_observe(self._create_on_n_changed(i)) self._value = self._calculate_value() def _create_on_n_changed(self, index: int) -> Callable[[_S], None]: def on_change(new_value: _S) -> None: self.gotten_values[index] = new_value self._on_result_change(self._calculate_value()) + self._callbacks.append(on_change) # keep strong reference to callback so it won't be garbage collected return on_change def _calculate_value(self) -> _T: diff --git a/tests/test_values/test_float_values/test_float_values.py b/tests/test_values/test_float_values/test_float_values.py index ebe4e63..6e8b9fb 100644 --- a/tests/test_values/test_float_values/test_float_values.py +++ b/tests/test_values/test_float_values/test_float_values.py @@ -1,4 +1,6 @@ -from spellbind.float_values import FloatConstant, MaxFloatValues, MinFloatValues +import gc + +from spellbind.float_values import FloatConstant, MaxFloatValues, MinFloatValues, FloatVariable from spellbind.values import SimpleVariable @@ -49,3 +51,28 @@ def test_min_float_values_with_literals(): a.value = 5.1 assert min_val.value == 5.1 + + +def test_add_float_values_keeps_reference(): + v0 = FloatVariable(1.5) + v1 = FloatVariable(2.5) + v2 = v0 + v1 + assert len(v0._on_change._subscriptions) == 1 + gc.collect() + + v0.value = 3.5 + assert len(v0._on_change._subscriptions) == 1 + + +def test_add_int_values_garbage_collected(): + v0 = FloatVariable(1.5) + v1 = FloatVariable(2.5) + v2 = v0 + v1 + assert len(v0._on_change._subscriptions) == 1 + assert len(v1._on_change._subscriptions) == 1 + v2 = None + gc.collect() + v0.value = 3.5 # trigger removal of weak references + v1.value = 4.5 # trigger removal of weak references + assert len(v0._on_change._subscriptions) == 0 + assert len(v1._on_change._subscriptions) == 0 diff --git a/tests/test_values/test_int_values/test_add_int_values.py b/tests/test_values/test_int_values/test_add_int_values.py index a70ab5a..16dede7 100644 --- a/tests/test_values/test_int_values/test_add_int_values.py +++ b/tests/test_values/test_int_values/test_add_int_values.py @@ -1,3 +1,5 @@ +import gc + from spellbind.float_values import FloatVariable from spellbind.int_values import IntVariable @@ -56,3 +58,29 @@ def test_add_float_plus_int_value(): v1.value = 4 assert v2.value == 7.5 + + +def test_add_int_values_keeps_reference(): + v0 = IntVariable(1) + v1 = IntVariable(2) + v2 = v0 + v1 + assert len(v0._on_change._subscriptions) == 1 + gc.collect() + + v0.value = 3 + v1.value = 4 + assert len(v0._on_change._subscriptions) == 1 + + +def test_add_int_values_garbage_collected(): + v0 = IntVariable(1) + v1 = IntVariable(2) + v2 = v0 + v1 + assert len(v0._on_change._subscriptions) == 1 + assert len(v1._on_change._subscriptions) == 1 + v2 = None + gc.collect() + v0.value = 3 # trigger removal of weak references + v1.value = 4 # trigger removal of weak references + assert len(v0._on_change._subscriptions) == 0 + assert len(v1._on_change._subscriptions) == 0