diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py
index 73bc26b360..2eeae3aaf1 100644
--- a/flytekit/remote/entities.py
+++ b/flytekit/remote/entities.py
@@ -730,7 +730,10 @@ def promote_from_model(
         tasks: Optional[Dict[Identifier, FlyteTask]] = None,
         node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None,
     ) -> FlyteWorkflow:
-        base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes)
+        # failure_node is held separately from nodes; append it when set so get_non_system_nodes sees it too.
+        base_model_non_system_nodes = cls.get_non_system_nodes(
+            base_model.nodes + ([base_model.failure_node] if base_model.failure_node is not None else [])
+        )
 
         node_map = {}
         converted_sub_workflows = {}
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index 18481d9e69..a17fc6f62b 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -1358,3 +1358,13 @@ def test_run_wf_with_resource_requests_override(register):
         ],
         limits=[],
     )
+
+
+def test_workflow_with_failure_node():
+    """Run ``wf`` from with_failure_node.py, fetch and wait on the execution, and expect the FAILED phase."""
+    execution_id = run("with_failure_node.py", "wf")
+    remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
+    execution = remote.fetch_execution(name=execution_id)
+    execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
+    print("Execution Error:", execution.error)
+    assert execution.closure.phase == WorkflowExecutionPhase.FAILED, f"Expected phase FAILED, got: {execution.closure.phase}"
diff --git a/tests/flytekit/integration/remote/workflows/basic/with_failure_node.py b/tests/flytekit/integration/remote/workflows/basic/with_failure_node.py
new file mode 100644
index 0000000000..f1a29d8f94
--- /dev/null
+++ b/tests/flytekit/integration/remote/workflows/basic/with_failure_node.py
@@ -0,0 +1,35 @@
+import typing
+
+import flytekit as fl
+from flytekit import WorkflowFailurePolicy
+from flytekit.types.error.error import FlyteError
+
+
+@fl.task
+def create_cluster(name: str):
+    """Print a message announcing creation of the cluster ``name``."""
+    print(f"Creating cluster: {name}")
+
+
+@fl.task
+def t1(a: int, b: str):
+    """Print ``a`` and ``b``, then always raise ``ValueError``."""
+    print(f"{a} {b}")
+    raise ValueError("Fail!")
+
+
+@fl.task
+def clean_up(name: str, err: typing.Optional[FlyteError] = None):
+    """Print which cluster is being deleted and the error passed in; used as ``on_failure`` of ``wf``."""
+    print(f"Deleting cluster {name} due to {err}")
+
+
+@fl.workflow(
+    on_failure=clean_up,
+    failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
+)
+def wf(name: str = "my_workflow"):
+    """Chain ``create_cluster`` before ``t1`` (which always raises), with ``clean_up`` as the failure handler."""
+    c = create_cluster(name=name)
+    t = t1(a=1, b="2")
+    c >> t
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index e3ae715c9a..b5709aff1d 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -120,6 +120,7 @@ def test_remote_fetch_execution(remote):
     )
     mock_client = MagicMock()
     mock_client.get_execution.return_value = admin_workflow_execution
+    mock_client.get_workflow.return_value.closure.compiled_workflow.primary.template.failure_node = None
     remote._client = mock_client
     flyte_workflow_execution = remote.fetch_execution(name="n1")
     assert flyte_workflow_execution.id == admin_workflow_execution.id
@@ -562,6 +563,7 @@ def mock_flyte_remote_client():
     with patch("flytekit.remote.remote.FlyteRemote.client") as mock_flyte_remote_client:
         mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.sql = None
         mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None
+        mock_flyte_remote_client.get_workflow.return_value.closure.compiled_workflow.primary.template.failure_node = None
         yield mock_flyte_remote_client
 
 