diff --git a/iwf/command_results.py b/iwf/command_results.py index e24bb09..ee0a0ef 100644 --- a/iwf/command_results.py +++ b/iwf/command_results.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Any, Union -from iwf.errors import WorkflowDefinitionError +from iwf.errors import WorkflowDefinitionError, NotRegisteredError from iwf.iwf_api.models import ( ChannelRequestStatus, CommandResults as IdlCommandResults, @@ -10,6 +10,7 @@ ) from iwf.iwf_api.types import Unset from iwf.object_encoder import ObjectEncoder +from iwf.type_store import TypeStore @dataclass @@ -43,7 +44,7 @@ class CommandResults: def from_idl_command_results( idl_results: Union[Unset, IdlCommandResults], - internal_channel_types: dict[str, typing.Optional[type]], + internal_channel_types: TypeStore, signal_channel_types: dict[str, typing.Optional[type]], object_encoder: ObjectEncoder, ) -> CommandResults: @@ -58,18 +59,13 @@ def from_idl_command_results( if not isinstance(idl_results.inter_state_channel_results, Unset): for inter in idl_results.inter_state_channel_results: - val_type = internal_channel_types.get(inter.channel_name) - if val_type is None: - # fallback to assume it's prefix - # TODO use is_prefix to implement like Java SDK - for name, t in internal_channel_types.items(): - if inter.channel_name.startswith(name): - val_type = t - break - if val_type is None: + + try: + val_type = internal_channel_types.get_type(inter.channel_name) + except NotRegisteredError as exception: raise WorkflowDefinitionError( "internal channel is not registered: " + inter.channel_name - ) + ) from exception encoded = object_encoder.decode(inter.value, val_type) diff --git a/iwf/communication.py b/iwf/communication.py index 7b5a1f7..78d141a 100644 --- a/iwf/communication.py +++ b/iwf/communication.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Union -from iwf.errors import WorkflowDefinitionError +from iwf.errors import WorkflowDefinitionError, NotRegisteredError from iwf.iwf_api.models import ( EncodedObject, InterStateChannelPublishing, @@ -9,10 +9,11 @@ ) from iwf.object_encoder import ObjectEncoder from iwf.state_movement import StateMovement +from iwf.type_store import TypeStore class Communication: - _internal_channel_type_store: dict[str, Optional[type]] + _internal_channel_type_store: TypeStore _signal_channel_type_store: dict[str, Optional[type]] _object_encoder: ObjectEncoder _to_publish_internal_channel: dict[str, list[EncodedObject]] @@ -22,7 +23,7 @@ class Communication: def __init__( self, - internal_channel_type_store: dict[str, Optional[type]], + internal_channel_type_store: TypeStore, signal_channel_type_store: dict[str, Optional[type]], object_encoder: ObjectEncoder, internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos], @@ -47,17 +48,12 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No self._state_movements.append(movement) def publish_to_internal_channel(self, channel_name: str, value: Any = None): - registered_type = self._internal_channel_type_store.get(channel_name) - - if registered_type is None: - for name, t in self._internal_channel_type_store.items(): - if channel_name.startswith(name): - registered_type = t - - if registered_type is None: + try: + registered_type = self._internal_channel_type_store.get_type(channel_name) + except NotRegisteredError as exception: raise WorkflowDefinitionError( f"InternalChannel channel_name is not defined {channel_name}" - ) + ) from exception if ( value is not None @@ -84,14 +80,11 @@ def get_to_trigger_state_movements(self) -> list[StateMovement]: return self._state_movements def get_internal_channel_size(self, channel_name): - registered_type = self._internal_channel_type_store.get(channel_name) - - if registered_type is None: - for name, t in self._internal_channel_type_store.items(): - if channel_name.startswith(name): - registered_type = t + is_type_registered = self._internal_channel_type_store.is_valid_name_or_prefix( + channel_name + ) - if registered_type is None: + if is_type_registered is False: raise WorkflowDefinitionError( f"InternalChannel channel_name is not defined {channel_name}" ) diff --git a/iwf/errors.py b/iwf/errors.py index 3c39cf2..3d6fa67 100644 --- a/iwf/errors.py +++ b/iwf/errors.py @@ -18,6 +18,10 @@ class InvalidArgumentError(Exception): pass +class NotRegisteredError(Exception): + pass + + class HttpError(RuntimeError): def __init__(self, status: int, err_resp: ErrorResponse): super().__init__(err_resp.detail) diff --git a/iwf/registry.py b/iwf/registry.py index 8619039..dc4cb0c 100644 --- a/iwf/registry.py +++ b/iwf/registry.py @@ -4,6 +4,7 @@ from iwf.errors import InvalidArgumentError, WorkflowDefinitionError from iwf.persistence_schema import PersistenceFieldType from iwf.rpc import RPCInfo +from iwf.type_store import TypeStore, Type from iwf.workflow import ObjectWorkflow, get_workflow_type from iwf.workflow_state import WorkflowState, get_state_id @@ -12,7 +13,7 @@ class Registry: _workflow_store: dict[str, ObjectWorkflow] _starting_state_store: dict[str, WorkflowState] _state_store: dict[str, dict[str, WorkflowState]] - _internal_channel_type_store: dict[str, dict[str, Optional[type]]] + _internal_channel_type_store: dict[str, TypeStore] _signal_channel_type_store: dict[str, dict[str, Optional[type]]] _data_attribute_types: dict[str, dict[str, Optional[type]]] _rpc_infos: dict[str, dict[str, RPCInfo]] @@ -63,7 +64,7 @@ def get_workflow_state_with_check( def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]: return self._state_store[wf_type] - def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]: + def get_internal_channel_type_store(self, wf_type: str) -> TypeStore: return self._internal_channel_type_store[wf_type] def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]: @@ -83,13 +84,17 @@ def _register_workflow_type(self, wf: ObjectWorkflow): def _register_internal_channels(self, wf: ObjectWorkflow): wf_type = get_workflow_type(wf) - types: dict[str, Optional[type]] = {} + + if wf_type not in self._internal_channel_type_store: + self._internal_channel_type_store[wf_type] = TypeStore( + Type.INTERNAL_CHANNEL + ) + for method in wf.get_communication_schema().communication_methods: if method.method_type == CommunicationMethodType.InternalChannel: - types[method.name] = method.value_type - # TODO use is_prefix to implement like Java SDK - # - self._internal_channel_type_store[wf_type] = types + self._internal_channel_type_store[wf_type].add_internal_channel_def( + method + ) def _register_signal_channels(self, wf: ObjectWorkflow): wf_type = get_workflow_type(wf) diff --git a/iwf/tests/test_internal_channel.py b/iwf/tests/test_internal_channel.py index 727cdb4..ae56625 100644 --- a/iwf/tests/test_internal_channel.py +++ b/iwf/tests/test_internal_channel.py @@ -1,5 +1,6 @@ import inspect import time +import unittest from iwf.client import Client from iwf.command_request import CommandRequest, InternalChannelCommand @@ -133,8 +134,9 @@ def get_communication_schema(self) -> CommunicationSchema: client = Client(registry) -def test_internal_channel_workflow(): - wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" +class TestConditionalComplete(unittest.TestCase): + def test_internal_channel_workflow(self): + wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" - client.start_workflow(InternalChannelWorkflow, wf_id, 100, None) - client.get_simple_workflow_result_with_wait(wf_id, None) + client.start_workflow(InternalChannelWorkflow, wf_id, 100, None) + client.get_simple_workflow_result_with_wait(wf_id, None) diff --git a/iwf/tests/test_internal_channel_with_no_prefix_channel.py b/iwf/tests/test_internal_channel_with_no_prefix_channel.py new file mode 100644 index 0000000..9da6bc5 --- /dev/null +++ b/iwf/tests/test_internal_channel_with_no_prefix_channel.py @@ -0,0 +1,123 @@ +import inspect +import time +import unittest + +from iwf.client import Client +from iwf.command_request import CommandRequest, InternalChannelCommand +from iwf.command_results import CommandResults +from iwf.communication import Communication +from iwf.communication_schema import CommunicationMethod, CommunicationSchema +from iwf.persistence import Persistence +from iwf.state_decision import StateDecision +from iwf.state_schema import StateSchema +from iwf.tests.worker_server import registry +from iwf.workflow import ObjectWorkflow +from iwf.workflow_context import WorkflowContext +from iwf.workflow_state import T, WorkflowState + +internal_channel_name = "internal-channel-1" + +test_non_prefix_channel_name = "test-channel-" +test_non_prefix_channel_name_with_suffix = test_non_prefix_channel_name + "abc" + + +class InitState(WorkflowState[None]): + def execute( + self, + ctx: WorkflowContext, + input: T, + command_results: CommandResults, + persistence: Persistence, + communication: Communication, + ) -> StateDecision: + return StateDecision.multi_next_states( + WaitAnyWithPublishState, WaitAllThenPublishState + ) + + +class WaitAnyWithPublishState(WorkflowState[None]): + def wait_until( + self, + ctx: WorkflowContext, + input: T, + persistence: Persistence, + communication: Communication, + ) -> CommandRequest: + # Trying to publish to a non-existing channel; this would only work if test_channel_name_non_prefix was defined as a prefix channel + communication.publish_to_internal_channel( + test_non_prefix_channel_name_with_suffix, "str-value-for-prefix" + ) + return CommandRequest.for_any_command_completed( + InternalChannelCommand.by_name(internal_channel_name), + ) + + def execute( + self, + ctx: WorkflowContext, + input: T, + command_results: CommandResults, + persistence: Persistence, + communication: Communication, + ) -> StateDecision: + return StateDecision.graceful_complete_workflow() + + +class WaitAllThenPublishState(WorkflowState[None]): + def wait_until( + self, + ctx: WorkflowContext, + input: T, + persistence: Persistence, + communication: Communication, + ) -> CommandRequest: + return CommandRequest.for_all_command_completed( + InternalChannelCommand.by_name(test_non_prefix_channel_name), + ) + + def execute( + self, + ctx: WorkflowContext, + input: T, + command_results: CommandResults, + persistence: Persistence, + communication: Communication, + ) -> StateDecision: + communication.publish_to_internal_channel(internal_channel_name, None) + return StateDecision.dead_end + + +class InternalChannelWorkflowWithNoPrefixChannel(ObjectWorkflow): + def get_workflow_states(self) -> StateSchema: + return StateSchema.with_starting_state( + InitState(), WaitAnyWithPublishState(), WaitAllThenPublishState() + ) + + def get_communication_schema(self) -> CommunicationSchema: + return CommunicationSchema.create( + CommunicationMethod.internal_channel_def(internal_channel_name, type(None)), + # Defining a standard channel (non-prefix) to make sure messages to the channel with a suffix added will not be accepted + CommunicationMethod.internal_channel_def(test_non_prefix_channel_name, str), + ) + + +wf = InternalChannelWorkflowWithNoPrefixChannel() +registry.add_workflow(wf) +client = Client(registry) + + +class TestInternalChannelWithNoPrefix(unittest.TestCase): + def test_internal_channel_workflow_with_no_prefix_channel(self): + wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" + + client.start_workflow( + InternalChannelWorkflowWithNoPrefixChannel, wf_id, 5, None + ) + + with self.assertRaises(Exception) as context: + client.wait_for_workflow_completion(wf_id, None) + + self.assertIn("FAILED", context.exception.workflow_status) + self.assertIn( + f"WorkerExecutionError: InternalChannel channel_name is not defined {test_non_prefix_channel_name_with_suffix}", + context.exception.error_message, + ) diff --git a/iwf/type_store.py b/iwf/type_store.py new file mode 100644 index 0000000..e2a552e --- /dev/null +++ b/iwf/type_store.py @@ -0,0 +1,68 @@ +from typing import Optional +from enum import Enum + +from iwf.communication_schema import CommunicationMethod +from iwf.errors import WorkflowDefinitionError, NotRegisteredError + + +class Type(Enum): + INTERNAL_CHANNEL = 1 + # TODO: extend to other types + # DATA_ATTRIBUTE = 2 + # SIGNAL_CHANNEL = 3 + + +class TypeStore: + _class_type: Type + _name_to_type_store: dict[str, Optional[type]] + _prefix_to_type_store: dict[str, Optional[type]] + + def __init__(self, class_type: Type): + self._class_type = class_type + self._name_to_type_store = dict() + self._prefix_to_type_store = dict() + + def is_valid_name_or_prefix(self, name: str) -> bool: + t = self._do_get_type(name) + return t is not None + + def get_type(self, name: str) -> type: + t = self._do_get_type(name) + + if t is None: + raise NotRegisteredError(f"{self._class_type} not registered: {name}") + + return t + + def add_internal_channel_def(self, obj: CommunicationMethod): + if self._class_type != Type.INTERNAL_CHANNEL: + raise ValueError( + f"Cannot add internal channel definition to {self._class_type}" + ) + self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type) + + def _do_get_type(self, name: str) -> Optional[type]: + if name in self._name_to_type_store: + return self._name_to_type_store[name] + + prefixes = self._prefix_to_type_store.keys() + + first = next((prefix for prefix in prefixes if name.startswith(prefix)), None) + + if first is None: + return None + + return self._prefix_to_type_store.get(first, None) + + def _do_add_to_store(self, is_prefix: bool, name: str, t: Optional[type]): + if is_prefix: + store = self._prefix_to_type_store + else: + store = self._name_to_type_store + + if name in store: + raise WorkflowDefinitionError( + f"{self._class_type} name/prefix {name} already exists" + ) + + store[name] = t diff --git a/iwf/worker_service.py b/iwf/worker_service.py index 5480962..520ff43 100644 --- a/iwf/worker_service.py +++ b/iwf/worker_service.py @@ -70,7 +70,7 @@ def handle_workflow_worker_rpc( wf_type = request.workflow_type rpc_info = self._registry.get_rpc_infos(wf_type)[request.rpc_name] - internal_channel_types = self._registry.get_internal_channel_types(wf_type) + internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) @@ -141,7 +141,7 @@ def handle_workflow_state_wait_until( state = self._registry.get_workflow_state_with_check( wf_type, request.workflow_state_id ) - internal_channel_types = self._registry.get_internal_channel_types(wf_type) + internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) @@ -187,7 +187,7 @@ def handle_workflow_state_execute( state = self._registry.get_workflow_state_with_check( wf_type, request.workflow_state_id ) - internal_channel_types = self._registry.get_internal_channel_types(wf_type) + internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) context = _from_idl_context(request.context)