Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@
from flytekit.models.project import Project
from flytekit.remote.backfill import create_backfill_workflow
from flytekit.remote.data import download_literal
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from flytekit.remote.entities import (
FlyteBranchNode,
FlyteLaunchPlan,
FlyteNode,
FlyteTask,
FlyteTaskNode,
FlyteWorkflow,
)
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from flytekit.remote.interface import TypedInterface
from flytekit.remote.lazy_entity import LazyEntity
Expand Down Expand Up @@ -2693,11 +2700,23 @@ def sync_node_execution(

# Handle the case where it's a branch node
elif execution._node.branch_node is not None:
logger.info(
"Skipping branch node execution for now - branch nodes will "
"not have inputs and outputs filled in"
)
return execution
sub_flyte_workflow = typing.cast(FlyteBranchNode, execution._node.flyte_entity)
sub_node_mapping = {}
if sub_flyte_workflow.if_else.case.then_node:
then_node = sub_flyte_workflow.if_else.case.then_node
sub_node_mapping[then_node.id] = then_node
if sub_flyte_workflow.if_else.other:
for case in sub_flyte_workflow.if_else.other:
then_node = case.then_node
sub_node_mapping[then_node.id] = then_node
if sub_flyte_workflow.if_else.else_node:
else_node = sub_flyte_workflow.if_else.else_node
sub_node_mapping[else_node.id] = else_node

execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
for cne in child_node_executions
]
else:
logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}")
raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}")
Expand Down Expand Up @@ -2839,15 +2858,19 @@ def _assign_inputs_and_outputs(
self,
execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution],
execution_data,
interface: TypedInterface,
interface: typing.Optional[TypedInterface] = None,
):
"""Helper for assigning synced inputs and outputs to an execution object."""
input_literal_map = self._get_input_literal_map(execution_data)
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context)
execution._inputs = LiteralsResolver(
input_literal_map.literals, interface.inputs if interface else None, self.context
)

if execution.is_done and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context)
execution._outputs = LiteralsResolver(
output_literal_map.literals, interface.outputs if interface else None, self.context
)
return execution

def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap:
Expand Down
Loading