diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 5821b660f3..1af159108a 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -99,7 +99,14 @@ from flytekit.models.project import Project from flytekit.remote.backfill import create_backfill_workflow from flytekit.remote.data import download_literal -from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow +from flytekit.remote.entities import ( + FlyteBranchNode, + FlyteLaunchPlan, + FlyteNode, + FlyteTask, + FlyteTaskNode, + FlyteWorkflow, +) from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.interface import TypedInterface from flytekit.remote.lazy_entity import LazyEntity @@ -2693,11 +2700,23 @@ def sync_node_execution( # Handle the case where it's a branch node elif execution._node.branch_node is not None: - logger.info( - "Skipping branch node execution for now - branch nodes will " - "not have inputs and outputs filled in" - ) - return execution + sub_flyte_workflow = typing.cast(FlyteBranchNode, execution._node.flyte_entity) + sub_node_mapping = {} + if sub_flyte_workflow.if_else.case.then_node: + then_node = sub_flyte_workflow.if_else.case.then_node + sub_node_mapping[then_node.id] = then_node + if sub_flyte_workflow.if_else.other: + for case in sub_flyte_workflow.if_else.other: + then_node = case.then_node + sub_node_mapping[then_node.id] = then_node + if sub_flyte_workflow.if_else.else_node: + else_node = sub_flyte_workflow.if_else.else_node + sub_node_mapping[else_node.id] = else_node + + execution._underlying_node_executions = [ + self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) + for cne in child_node_executions + ] else: logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") @@ -2839,15 +2858,19 @@ def _assign_inputs_and_outputs( self, execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], execution_data, - interface: TypedInterface, + interface: typing.Optional[TypedInterface] = None, ): """Helper for assigning synced inputs and outputs to an execution object.""" input_literal_map = self._get_input_literal_map(execution_data) - execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context) + execution._inputs = LiteralsResolver( + input_literal_map.literals, interface.inputs if interface else None, self.context + ) if execution.is_done and not execution.error: output_literal_map = self._get_output_literal_map(execution_data) - execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context) + execution._outputs = LiteralsResolver( + output_literal_map.literals, interface.outputs if interface else None, self.context + ) return execution def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap: