From 75dfa9220b90b5362b81fc6cb6ecaa443e28786b Mon Sep 17 00:00:00 2001 From: catfish Date: Thu, 14 Aug 2025 22:57:33 +0800 Subject: [PATCH 1/6] feat: add timeout parameter to ContainerTask for controlling execution duration Signed-off-by: catfish --- flytekit/core/container_task.py | 26 ++++++-- .../flytekit/unit/core/test_container_task.py | 62 +++++++++++++++++++ 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 2d99f3b8c0..1077a87dba 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,3 +1,4 @@ +import datetime import os import typing from enum import Enum @@ -61,6 +62,7 @@ def __init__( pod_template_name: Optional[str] = None, local_logs: bool = False, resources: Optional[Resources] = None, + timeout: Optional[typing.Union[float, int, "datetime.timedelta"]] = None, **kwargs, ): sec_ctx = None @@ -103,6 +105,7 @@ def __init__( ) self.pod_template = pod_template self.local_logs = local_logs + self._timeout = timeout @property def resources(self) -> ResourceSpec: @@ -279,14 +282,20 @@ def execute(self, **kwargs) -> LiteralMap: container = client.containers.run( self._image, command=commands, remove=True, volumes=volume_bindings, detach=True ) - # Wait for the container to finish the task - # TODO: Add a 'timeout' parameter to control the max wait time for the container to finish the task. + + # Wait for the container to finish the task, with timeout if specified + timeout_seconds = None + if self._timeout is not None: + if isinstance(self._timeout, datetime.timedelta): + timeout_seconds = self._timeout.total_seconds() + else: + timeout_seconds = float(self._timeout) if self.local_logs: for log in container.logs(stream=True): print(f"[Local Container] {log.strip()}") - container.wait() + container.wait(timeout=timeout_seconds) output_dict = self._get_output_dict(output_directory) outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict) @@ -330,8 +339,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None + pod_spec = _serialize_pod_spec(self.pod_template, self._get_container(settings), settings) + if self._timeout is not None: + import datetime + + if isinstance(self._timeout, datetime.timedelta): + timeout_seconds = int(self._timeout.total_seconds()) + else: + timeout_seconds = int(float(self._timeout)) + pod_spec["activeDeadlineSeconds"] = timeout_seconds return _task_model.K8sPod( - pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings), settings), + pod_spec=pod_spec, metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 1281a9ec14..d45f4a6767 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -1,7 +1,11 @@ import os import sys +import time +import docker + from collections import OrderedDict from typing import Tuple +from datetime import timedelta import pytest from kubernetes.client.models import ( @@ -238,3 +242,61 @@ def test_container_task_image_spec(mock_image_spec_builder): pod = ct.get_k8s_pod(default_serialization_settings) assert pod.pod_spec["containers"][0]["image"] == image_spec_1.image_name() assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name() + +def test_container_task_timeout(): + ct_with_timeout = ContainerTask( + name="timeout-test", + image="busybox", + command=["sleep", "5"], + timeout=1, + ) + + with pytest.raises((docker.errors.APIError, Exception)): + ct_with_timeout.execute() + + ct_with_timedelta = ContainerTask( + name="timedelta-timeout-test", + image="busybox", + command=["sleep", "2"], + timeout=timedelta(seconds=1), + ) + + with pytest.raises((docker.errors.APIError, Exception)): + ct_with_timedelta.execute() + + +def test_container_task_timeout_k8s_serialization(): + from datetime import timedelta + + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) + + ct_numeric = ContainerTask( + name="timeout-k8s-test", + image="busybox", + command=["echo", "hello"], + pod_template=pt, + timeout=60, + ) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + k8s_pod = ct_numeric.get_k8s_pod(default_serialization_settings) + assert k8s_pod.pod_spec["activeDeadlineSeconds"] == 60 + + ct_timedelta = ContainerTask( + name="timeout-k8s-timedelta-test", + image="busybox", + command=["echo", "hello"], + pod_template=pt, + timeout=timedelta(minutes=2), + ) + + k8s_pod_timedelta = ct_timedelta.get_k8s_pod(default_serialization_settings) + assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120 From 27668567ee75f5227537491eae1b1dd6cc3bfb62 Mon Sep 17 00:00:00 2001 From: catfish Date: Thu, 14 Aug 2025 23:32:16 +0800 Subject: [PATCH 2/6] test: update container task timeout tests with input/output directories Signed-off-by: catfish --- .../flytekit/unit/core/test_container_task.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index d45f4a6767..b7e43c5918 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -246,27 +246,33 @@ def test_container_task_image_spec(mock_image_spec_builder): def test_container_task_timeout(): ct_with_timeout = ContainerTask( name="timeout-test", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", image="busybox", - command=["sleep", "5"], + command=["sleep", "100"], timeout=1, ) + + with pytest.raises((docker.errors.APIError, Exception)): ct_with_timeout.execute() ct_with_timedelta = ContainerTask( name="timedelta-timeout-test", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", image="busybox", - command=["sleep", "2"], + command=["sleep", "100"], timeout=timedelta(seconds=1), ) + with pytest.raises((docker.errors.APIError, Exception)): ct_with_timedelta.execute() def test_container_task_timeout_k8s_serialization(): - from datetime import timedelta ps = V1PodSpec( containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] @@ -300,3 +306,18 @@ def test_container_task_timeout_k8s_serialization(): k8s_pod_timedelta = ct_timedelta.get_k8s_pod(default_serialization_settings) assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120 + +def test_container_task_no_timeout(): + + ct = ContainerTask( + name="no-timeout-task", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + image="busybox", + command=["sleep", "1"], + timeout=500, + ) + + + + ct.execute() \ No newline at end of file From 17c559c6c9c9abfcafba4045c243f6736b6dfd21 Mon Sep 17 00:00:00 2001 From: catfish Date: Thu, 14 Aug 2025 23:51:18 +0800 Subject: [PATCH 3/6] refactor: simplify timeout handling in ContainerTask to use only timedelta Signed-off-by: catfish --- flytekit/core/container_task.py | 14 ++----- .../flytekit/unit/core/test_container_task.py | 39 ++----------------- 2 files changed, 7 insertions(+), 46 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 1077a87dba..6201e0f5e3 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -62,7 +62,7 @@ def __init__( pod_template_name: Optional[str] = None, local_logs: bool = False, resources: Optional[Resources] = None, - timeout: Optional[typing.Union[float, int, "datetime.timedelta"]] = None, + timeout: Optional["datetime.timedelta"] = None, **kwargs, ): sec_ctx = None @@ -286,10 +286,7 @@ def execute(self, **kwargs) -> LiteralMap: # Wait for the container to finish the task, with timeout if specified timeout_seconds = None if self._timeout is not None: - if isinstance(self._timeout, datetime.timedelta): - timeout_seconds = self._timeout.total_seconds() - else: - timeout_seconds = float(self._timeout) + timeout_seconds = self._timeout.total_seconds() if self.local_logs: for log in container.logs(stream=True): @@ -341,12 +338,7 @@ def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: return None pod_spec = _serialize_pod_spec(self.pod_template, self._get_container(settings), settings) if self._timeout is not None: - import datetime - - if isinstance(self._timeout, datetime.timedelta): - timeout_seconds = int(self._timeout.total_seconds()) - else: - timeout_seconds = int(float(self._timeout)) + timeout_seconds = int(self._timeout.total_seconds()) pod_spec["activeDeadlineSeconds"] = timeout_seconds return _task_model.K8sPod( pod_spec=pod_spec, diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index b7e43c5918..8d6d773d3b 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -244,29 +244,12 @@ def test_container_task_image_spec(mock_image_spec_builder): assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name() def test_container_task_timeout(): - ct_with_timeout = ContainerTask( - name="timeout-test", - input_data_dir="/var/inputs", - output_data_dir="/var/outputs", - image="busybox", - command=["sleep", "100"], - timeout=1, - ) - - - - with pytest.raises((docker.errors.APIError, Exception)): - ct_with_timeout.execute() - ct_with_timedelta = ContainerTask( name="timedelta-timeout-test", - input_data_dir="/var/inputs", - output_data_dir="/var/outputs", image="busybox", command=["sleep", "100"], timeout=timedelta(seconds=1), ) - with pytest.raises((docker.errors.APIError, Exception)): ct_with_timedelta.execute() @@ -278,24 +261,13 @@ def test_container_task_timeout_k8s_serialization(): containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] ) pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) - - ct_numeric = ContainerTask( - name="timeout-k8s-test", - image="busybox", - command=["echo", "hello"], - pod_template=pt, - timeout=60, - ) - + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") default_image_config = ImageConfig(default_image=default_image) default_serialization_settings = SerializationSettings( project="p", domain="d", version="v", image_config=default_image_config ) - k8s_pod = ct_numeric.get_k8s_pod(default_serialization_settings) - assert k8s_pod.pod_spec["activeDeadlineSeconds"] == 60 - ct_timedelta = ContainerTask( name="timeout-k8s-timedelta-test", image="busybox", @@ -308,16 +280,13 @@ def test_container_task_timeout_k8s_serialization(): assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120 def test_container_task_no_timeout(): - - ct = ContainerTask( + ct_timedelta = ContainerTask( name="no-timeout-task", input_data_dir="/var/inputs", output_data_dir="/var/outputs", image="busybox", command=["sleep", "1"], - timeout=500, + timeout=timedelta(seconds=500), ) - - - ct.execute() \ No newline at end of file + ct_timedelta.execute() \ No newline at end of file From 323ff509043b0b40181a61c56d8389061bb550cd Mon Sep 17 00:00:00 2001 From: catfish Date: Fri, 15 Aug 2025 00:11:35 +0800 Subject: [PATCH 4/6] refactor: modify space by lint hint Signed-off-by: catfish --- tests/flytekit/unit/core/test_container_task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 8d6d773d3b..bda0f049ca 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -261,7 +261,7 @@ def test_container_task_timeout_k8s_serialization(): containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] ) pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) - + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") default_image_config = ImageConfig(default_image=default_image) default_serialization_settings = SerializationSettings( @@ -285,8 +285,8 @@ def test_container_task_no_timeout(): input_data_dir="/var/inputs", output_data_dir="/var/outputs", image="busybox", - command=["sleep", "1"], + command=["sleep", "1"], timeout=timedelta(seconds=500), ) - - ct_timedelta.execute() \ No newline at end of file + + ct_timedelta.execute() From 358788be881d43ee4fddf42a005eb0d468fbc2d0 Mon Sep 17 00:00:00 2001 From: catfish Date: Sat, 16 Aug 2025 14:58:35 +0800 Subject: [PATCH 5/6] test: skip tests on unsupported platforms to improve CI reliability Signed-off-by: catfish --- tests/flytekit/unit/core/test_container_task.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index bda0f049ca..8695236674 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -243,6 +243,10 @@ def test_container_task_image_spec(mock_image_spec_builder): assert pod.pod_spec["containers"][0]["image"] == image_spec_1.image_name() assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name() +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) def test_container_task_timeout(): ct_with_timedelta = ContainerTask( name="timedelta-timeout-test", @@ -254,7 +258,10 @@ def test_container_task_timeout(): with pytest.raises((docker.errors.APIError, Exception)): ct_with_timedelta.execute() - +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) def test_container_task_timeout_k8s_serialization(): ps = V1PodSpec( @@ -279,7 +286,12 @@ def test_container_task_timeout_k8s_serialization(): k8s_pod_timedelta = ct_timedelta.get_k8s_pod(default_serialization_settings) assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120 -def test_container_task_no_timeout(): + +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) +def test_container_task_within_timeout(): ct_timedelta = ContainerTask( name="no-timeout-task", input_data_dir="/var/inputs", From 0aa2567c07b4a2d052787a60bbaea31b0b327bfa Mon Sep 17 00:00:00 2001 From: catfish Date: Tue, 2 Sep 2025 00:04:13 +0800 Subject: [PATCH 6/6] feat: Set ContainerTask timeout through TaskMetadata for consistent semantics Ensure ContainerTask timeout behavior matches regular Python tasks by using TaskMetadata mechanism. Signed-off-by: catfish --- flytekit/core/container_task.py | 4 +- .../flytekit/unit/core/test_container_task.py | 67 ++++++++++++++++--- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 6201e0f5e3..0ec6a709b0 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -76,6 +76,9 @@ def __init__( metadata = metadata or TaskMetadata() metadata.pod_template_name = pod_template_name + if timeout is not None: + metadata.timeout = timeout + super().__init__( task_type="raw-container", name=name, @@ -283,7 +286,6 @@ def execute(self, **kwargs) -> LiteralMap: self._image, command=commands, remove=True, volumes=volume_bindings, detach=True ) - # Wait for the container to finish the task, with timeout if specified timeout_seconds = None if self._timeout is not None: timeout_seconds = self._timeout.total_seconds() diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 8695236674..ee189eea96 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -291,14 +291,65 @@ def test_container_task_timeout_k8s_serialization(): sys.platform in ["darwin", "win32"], reason="Skip if running on windows or macos due to CI Docker environment setup failure", ) -def test_container_task_within_timeout(): - ct_timedelta = ContainerTask( - name="no-timeout-task", - input_data_dir="/var/inputs", - output_data_dir="/var/outputs", +def test_container_task_timeout_in_metadata(): + from flytekit.core.base_task import TaskMetadata + + ct_with_timedelta = ContainerTask( + name="timeout-metadata-test", + image="busybox", + command=["echo", "hello"], + timeout=timedelta(minutes=5), + ) + + assert ct_with_timedelta.metadata.timeout == timedelta(minutes=5) + + # Test with custom metadata - timeout should be set in the provided metadata + custom_metadata = TaskMetadata(retries=3) + ct_with_custom_metadata = ContainerTask( + name="custom-metadata-timeout-test", image="busybox", - command=["sleep", "1"], - timeout=timedelta(seconds=500), + command=["echo", "hello"], + metadata=custom_metadata, + timeout=timedelta(seconds=30), ) - ct_timedelta.execute() + # Verify timeout is set in the custom metadata and retries are preserved + assert ct_with_custom_metadata.metadata.timeout == timedelta(seconds=30) + assert ct_with_custom_metadata.metadata.retries == 3 + + ct_without_timeout = ContainerTask( + name="no-timeout-test", + image="busybox", + command=["echo", "hello"] + ) + + assert ct_without_timeout.metadata.timeout is None + + +def test_container_task_timeout_serialization(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ct_with_timeout = ContainerTask( + name="timeout-serialization-test", + image="busybox", + command=["echo", "hello"], + pod_template=pt, + timeout=timedelta(minutes=10), + ) + + from flytekit.tools.translator import get_serializable_task + from collections import OrderedDict + + serialized_task = get_serializable_task(OrderedDict(), default_serialization_settings, ct_with_timeout) + + k8s_pod = ct_with_timeout.get_k8s_pod(default_serialization_settings) + assert k8s_pod.pod_spec["activeDeadlineSeconds"] == 600 # 10 minutes in seconds