diff --git a/src/frequenz/sdk/timeseries/formulas/_functions.py b/src/frequenz/sdk/timeseries/formulas/_functions.py index 0530fe750..010de1f51 100644 --- a/src/frequenz/sdk/timeseries/formulas/_functions.py +++ b/src/frequenz/sdk/timeseries/formulas/_functions.py @@ -100,6 +100,14 @@ class Coalesce(Function[QuantityT]): """A function that returns the first non-None argument.""" num_subscribed: int = 0 + """Number of parameters currently subscribed to.""" + + num_samples: int = 0 + """Number of samples received since last subscription change. + + This only counts samples from parameters other than the last one, + and may indicate that the last parameter can be unsubscribed from. + """ @property @override @@ -122,15 +130,16 @@ async def __call__(self) -> Sample[QuantityT] | QuantityT | None: match arg: case Sample(timestamp, value): if value is not None: - # Found a non-None value, unsubscribe from subsequent params + # Found a non-None value if ctr < self.num_subscribed: - await self._unsubscribe_all_params_after(ctr) + self.num_samples += 1 + # Unsubscribe from last component when the + # other component streams are reasonably stable. + if self.num_samples >= 3: + await self._unsubscribe_last_param() return arg ts = timestamp case Quantity(): - # Found a non-None value, unsubscribe from subsequent params - if ctr < self.num_subscribed: - await self._unsubscribe_all_params_after(ctr) if ts is not None: return Sample(timestamp=ts, value=arg) return arg @@ -166,16 +175,18 @@ async def _subscribe_next_param(self) -> None: ) await self.params[self.num_subscribed].subscribe() self.num_subscribed += 1 + self.num_samples = 0 - async def _unsubscribe_all_params_after(self, index: int) -> None: - """Unsubscribe from parameters after the given index.""" - for param in self.params[index:]: + async def _unsubscribe_last_param(self) -> None: + """Unsubscribe from the last parameter.""" + if self.num_subscribed > 1: _logger.debug( "Coalesce unsubscribing from param: %s", - param, + self.num_subscribed, ) - await param.unsubscribe() - self.num_subscribed = index + await self.params[self.num_subscribed - 1].unsubscribe() + self.num_subscribed -= 1 + self.num_samples = 0 @dataclass diff --git a/tests/timeseries/_formulas/test_formula_validation.py b/tests/timeseries/_formulas/test_formula_validation.py new file mode 100644 index 000000000..fd2b76131 --- /dev/null +++ b/tests/timeseries/_formulas/test_formula_validation.py @@ -0,0 +1,185 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""Tests for the Formula implementation.""" + +from unittest.mock import Mock + +import pytest +from frequenz.quantities import Quantity + +from frequenz.sdk.timeseries.formulas._exceptions import FormulaSyntaxError +from frequenz.sdk.timeseries.formulas._parser import parse + + +@pytest.mark.parametrize( + ("formula_str", "parsed_formula_str"), + [ + ("#1", "[f](#1)"), + ("-(1+#1)", "[f](0.0 - (1.0 + #1))"), + ("1*(2+3)", "[f](1.0 * (2.0 + 3.0))"), + ], +) +async def test_parser_validation( + formula_str: str, + parsed_formula_str: str, +) -> None: + """Test formula parser validation.""" + try: + formula = parse( + name="f", + formula=formula_str, + create_method=Quantity, + telemetry_fetcher=Mock(), + ) + assert str(formula) == parsed_formula_str + except FormulaSyntaxError: + assert False, "Parser should not raise an error for this formula" + + +@pytest.mark.parametrize( + ("formula_str", "expected_error_line"), + [ + ( + "1++", + " ^ Expected expression", + ), + ( + "1**", + " ^ Expected expression", + ), + ( + "--1", + " ^ Expected expression", + ), + ( + "(", + " ^ Expected expression", + ), + ( + "(1", + "^ Unmatched parenthesis", + ), + ( + "max", + " ^ Expected '(' after function name", + ), + ( + "max()", + " ^ Expected argument", + ), + ( + "max(1(", + " ^ Expected ',' or ')'", + ), + ( + "max(1", + " ^ Unmatched parenthesis", + ), + ( + "foo", + "^^^ Unknown function name", + ), + ( + "foo(1)", + "^^^ Unknown function name", + ), + ( + "max(1,,2)", + " ^ Expected argument", + ), + ( + "1 2", + " ^ Unexpected token", + ), + ( + "1, 2", + " ^ Unexpected token", + ), + ( + "max(1, 2,)", + " ^ Expected argument", + ), + ( + "max(1, 2))", + " ^ Unexpected token", + ), + ( + "max(1, 2),", + " ^ Unexpected token", + ), + ], +) +async def test_parser_validation_errors( + formula_str: str, expected_error_line: str +) -> None: + """Test formula parser validation.""" + with pytest.raises(FormulaSyntaxError) as error: + _ = parse( + name="f", + formula=formula_str, + create_method=Quantity, + telemetry_fetcher=Mock(), + ) + + assert str(error.value) == ( + "Formula syntax error:\n" + f" Formula: {formula_str}\n" + f" {expected_error_line}" + ) + + +@pytest.mark.parametrize( + ("formula_str", "expected_error"), + [ + # Long formula with error near start -> Ellipsize end + ( + "max(coalesce(#1001, %1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#1007, #1008, 0))", # noqa: E501 + "Formula syntax error:\n" + " Formula: max(coalesce(#1001, %1002, 0), coalesce(#1003, #1004, 0), coalesc ...\n" + " ^ Unexpected character", + ), + # Long formula with error near the end -> Ellipsize start + ( + "max(coalesce(#1001, #1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#10.07, #1008, 0))", # noqa: E501 + "Formula syntax error:\n" + " Formula: ... 0), coalesce(#1005, #1006, 0), coalesce(#10.07, #1008, 0))\n" + " ^ Unexpected character", + ), + # Very long formula with error in the middle -> Ellipsize both sides + ( + "max(coalesce(#1001, #1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#1007, #1008, 0)) :) " # noqa: E501 + "min(coalesce(#2001, #2002, 0), coalesce(#2003, #2004, 0), coalesce(#2005, #2006, 0), coalesce(#2007, #2008, 0))", # noqa: E501 + "Formula syntax error:\n" + " Formula: ... 005, #1006, 0), coalesce(#1007, #1008, 0)) :) min(coalesce(#2 ...\n" + " ^ Unexpected character", + ), + ], +) +async def test_parser_validation_errors_in_long_formulas( + formula_str: str, expected_error: str +) -> None: + """Test formula parser validation for long formulas.""" + with pytest.raises(FormulaSyntaxError) as error: + _ = parse( + name="f", + formula=formula_str, + create_method=Quantity, + telemetry_fetcher=Mock(), + ) + + assert str(error.value) == expected_error + assert all(len(line) <= 80 for line in str(error.value).splitlines()) + + +async def test_empty_formula() -> None: + """Test formula parser validation.""" + with pytest.raises(FormulaSyntaxError) as error: + _ = parse( + name="f", + formula="", + create_method=Quantity, + telemetry_fetcher=Mock(), + ) + + assert str(error.value) == "Empty formula" diff --git a/tests/timeseries/_formulas/test_formulas.py b/tests/timeseries/_formulas/test_formulas.py index d4b193867..ed355e91c 100644 --- a/tests/timeseries/_formulas/test_formulas.py +++ b/tests/timeseries/_formulas/test_formulas.py @@ -8,6 +8,7 @@ from collections import OrderedDict from collections.abc import Callable from datetime import datetime, timedelta +from typing import NamedTuple from unittest.mock import AsyncMock, MagicMock import async_solipsism @@ -17,7 +18,6 @@ from frequenz.quantities import Quantity from frequenz.sdk.timeseries import Sample -from frequenz.sdk.timeseries.formulas._exceptions import FormulaSyntaxError from frequenz.sdk.timeseries.formulas._formula import Formula, FormulaBuilder from frequenz.sdk.timeseries.formulas._parser import parse from frequenz.sdk.timeseries.formulas._resampled_stream_fetcher import ( @@ -344,180 +344,6 @@ async def test_max_min_coalesce(self) -> None: ) -class TestFormulaValidation: - """Tests for Formula validation.""" - - @pytest.mark.parametrize( - ("formula_str", "parsed_formula_str"), - [ - ("#1", "[f](#1)"), - ("-(1+#1)", "[f](0.0 - (1.0 + #1))"), - ("1*(2+3)", "[f](1.0 * (2.0 + 3.0))"), - ], - ) - async def test_parser_validation( - self, - formula_str: str, - parsed_formula_str: str, - ) -> None: - """Test formula parser validation.""" - try: - formula = parse( - name="f", - formula=formula_str, - create_method=Quantity, - telemetry_fetcher=MagicMock(spec=ResampledStreamFetcher), - ) - assert str(formula) == parsed_formula_str - except FormulaSyntaxError: - assert False, "Parser should not raise an error for this formula" - - @pytest.mark.parametrize( - ("formula_str", "expected_error_line"), - [ - ( - "1++", - " ^ Expected expression", - ), - ( - "1**", - " ^ Expected expression", - ), - ( - "--1", - " ^ Expected expression", - ), - ( - "(", - " ^ Expected expression", - ), - ( - "(1", - "^ Unmatched parenthesis", - ), - ( - "max", - " ^ Expected '(' after function name", - ), - ( - "max()", - " ^ Expected argument", - ), - ( - "max(1(", - " ^ Expected ',' or ')'", - ), - ( - "max(1", - " ^ Unmatched parenthesis", - ), - ( - "foo", - "^^^ Unknown function name", - ), - ( - "foo(1)", - "^^^ Unknown function name", - ), - ( - "max(1,,2)", - " ^ Expected argument", - ), - ( - "1 2", - " ^ Unexpected token", - ), - ( - "1, 2", - " ^ Unexpected token", - ), - ( - "max(1, 2,)", - " ^ Expected argument", - ), - ( - "max(1, 2))", - " ^ Unexpected token", - ), - ( - "max(1, 2),", - " ^ Unexpected token", - ), - ], - ) - async def test_parser_validation_errors( - self, formula_str: str, expected_error_line: str - ) -> None: - """Test formula parser validation.""" - with pytest.raises(FormulaSyntaxError) as error: - _ = parse( - name="f", - formula=formula_str, - create_method=Quantity, - telemetry_fetcher=MagicMock(spec=ResampledStreamFetcher), - ) - - assert str(error.value) == ( - "Formula syntax error:\n" - f" Formula: {formula_str}\n" - f" {expected_error_line}" - ) - - @pytest.mark.parametrize( - ("formula_str", "expected_error"), - [ - # Long formula with error near start -> Ellipsize end - ( - "max(coalesce(#1001, %1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#1007, #1008, 0))", # noqa: E501 - "Formula syntax error:\n" - " Formula: max(coalesce(#1001, %1002, 0), coalesce(#1003, #1004, 0), coalesc ...\n" - " ^ Unexpected character", - ), - # Long formula with error near the end -> Ellipsize start - ( - "max(coalesce(#1001, #1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#10.07, #1008, 0))", # noqa: E501 - "Formula syntax error:\n" - " Formula: ... 0), coalesce(#1005, #1006, 0), coalesce(#10.07, #1008, 0))\n" - " ^ Unexpected character", - ), - # Very long formula with error in the middle -> Ellipsize both sides - ( - "max(coalesce(#1001, #1002, 0), coalesce(#1003, #1004, 0), coalesce(#1005, #1006, 0), coalesce(#1007, #1008, 0)) :) " # noqa: E501 - "min(coalesce(#2001, #2002, 0), coalesce(#2003, #2004, 0), coalesce(#2005, #2006, 0), coalesce(#2007, #2008, 0))", # noqa: E501 - "Formula syntax error:\n" - " Formula: ... 005, #1006, 0), coalesce(#1007, #1008, 0)) :) min(coalesce(#2 ...\n" - " ^ Unexpected character", - ), - ], - ) - async def test_parser_validation_errors_in_long_formulas( - self, formula_str: str, expected_error: str - ) -> None: - """Test formula parser validation for long formulas.""" - with pytest.raises(FormulaSyntaxError) as error: - _ = parse( - name="f", - formula=formula_str, - create_method=Quantity, - telemetry_fetcher=MagicMock(spec=ResampledStreamFetcher), - ) - - assert str(error.value) == expected_error - assert all(len(line) <= 80 for line in str(error.value).splitlines()) - - async def test_empty_formula(self) -> None: - """Test formula parser validation.""" - with pytest.raises(FormulaSyntaxError) as error: - _ = parse( - name="f", - formula="", - create_method=Quantity, - telemetry_fetcher=MagicMock(spec=ResampledStreamFetcher), - ) - - assert str(error.value) == "Empty formula" - - class TestFormulaComposition: """Tests for formula channels.""" @@ -784,8 +610,9 @@ async def test_coalesce(self) -> None: ([None, None, 15.0], None), ([None, None, 15.0], 15.0), ([10.0, None, 15.0], 50.0), - ([None, None, 15.0], None), - ([None, None, 15.0], None), + # Subscription to c5 was kept because we only unsubscribe after 3 samples + ([None, None, 15.0], 15.0), + ([None, None, 15.0], 15.0), ([None, None, 15.0], 15.0), ([None, None, None], None), ], @@ -907,3 +734,129 @@ async def test_compound(self) -> None: ([15.0, 17.0, None, 5.0], None), ], ) + + +class TestCoalesceFunction: + """Test coalesce function subscribe/unsubscribe behavior.""" + + class CoalesceSample(NamedTuple): + """Helper class to represent expected behavior of coalesce function.""" + + values: list[float | None] + expected_subscriptions: list[bool] + + async def run_test( # pylint: disable=too-many-locals + self, + formula_str: str, + samples: list[CoalesceSample], + ) -> None: + """Run a test with the specs provided.""" + # Component IDs are 0, 1, 2 for convenience. + channels: list[Broadcast[Sample[Quantity]]] = [ + Broadcast(name=str(num)) for num in range(3) + ] + senders = [channel.new_sender() for channel in channels] + receivers: list[Receiver[Sample[Quantity]] | None] = [None, None, None] + + def new_receiver(component_id: ComponentId) -> Receiver[Sample[Quantity]]: + """Create a new receiver, overwriting any existing one. + + When Coalesce unsubscribes, it closes its receiver. + """ + comp_id = int(component_id) + receiver = channels[comp_id].new_receiver() + receivers[comp_id] = receiver + return receiver + + telem_fetcher = MagicMock(spec=ResampledStreamFetcher) + telem_fetcher.fetch_stream = AsyncMock(side_effect=new_receiver) + formula = parse( + name="f2", + formula=formula_str, + create_method=Quantity, + telemetry_fetcher=telem_fetcher, + ) + + result_chan = formula.new_receiver() + await asyncio.sleep(0.1) + now = datetime.now() + + async def send_sample(values: list[float | None]) -> None: + nonlocal now + now += timedelta(seconds=1) + _ = await asyncio.gather( + *[ + senders[comp_id].send( + Sample(now, None if not value else Quantity(value)) + ) + for comp_id, value in enumerate(values) + ] + ) + _ = await result_chan.receive() + + for sample in samples: + await send_sample(sample.values) + active_subscriptions = [ + receiver is not None and not getattr(receiver, "_closed", True) + for receiver in receivers + ] + assert active_subscriptions == sample.expected_subscriptions + + await formula.stop() + + async def test_coalesce_subscribe(self) -> None: + """Test coalesce subscribes when None values are encountered.""" + await self.run_test( + "COALESCE(#0, #1, #2, 0.0)", + [ + self.CoalesceSample( + values=[10.0, None, None], + expected_subscriptions=[True, False, False], + ), + # No need to subscribe unless stream #1 gives None + self.CoalesceSample( + values=[10.0, 12.0, 15.0], + expected_subscriptions=[True, False, False], + ), + # If None is encountered, one subscription is added per sample + self.CoalesceSample( + values=[None, None, 15.0], + expected_subscriptions=[True, True, False], + ), + self.CoalesceSample( + values=[None, None, 15.0], + expected_subscriptions=[True, True, True], + ), + ], + ) + + async def test_coalesce_unsubscribe(self) -> None: + """Test coalesce only unsubscribes after 3 samples.""" + await self.run_test( + "COALESCE(#0, #1, #2, 0.0)", + [ + # First subscription is added before the first sample. + # Every sample can add one subscription. + self.CoalesceSample( + values=[None, None, 15.0], + expected_subscriptions=[True, True, False], + ), + self.CoalesceSample( + values=[None, None, 15.0], + expected_subscriptions=[True, True, True], + ), + # After 3 samples, the last subscription is dropped. + self.CoalesceSample( + values=[None, 12.0, 15.0], + expected_subscriptions=[True, True, True], + ), + self.CoalesceSample( + values=[10.0, None, 15.0], + expected_subscriptions=[True, True, True], + ), + self.CoalesceSample( + values=[None, 12.0, 15.0], + expected_subscriptions=[True, True, False], + ), + ], + )