diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 2d99f3b8c0..0ec6a709b0 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,3 +1,4 @@ +import datetime import os import typing from enum import Enum @@ -61,6 +62,7 @@ def __init__( pod_template_name: Optional[str] = None, local_logs: bool = False, resources: Optional[Resources] = None, + timeout: Optional["datetime.timedelta"] = None, **kwargs, ): sec_ctx = None @@ -74,6 +76,9 @@ def __init__( metadata = metadata or TaskMetadata() metadata.pod_template_name = pod_template_name + if timeout is not None: + metadata.timeout = timeout + super().__init__( task_type="raw-container", name=name, @@ -103,6 +108,7 @@ def __init__( ) self.pod_template = pod_template self.local_logs = local_logs + self._timeout = timeout @property def resources(self) -> ResourceSpec: @@ -279,14 +285,16 @@ def execute(self, **kwargs) -> LiteralMap: container = client.containers.run( self._image, command=commands, remove=True, volumes=volume_bindings, detach=True ) - # Wait for the container to finish the task - # TODO: Add a 'timeout' parameter to control the max wait time for the container to finish the task. + + timeout_seconds = None + if self._timeout is not None: + timeout_seconds = self._timeout.total_seconds() if self.local_logs: for log in container.logs(stream=True): print(f"[Local Container] {log.strip()}") - container.wait() + container.wait(timeout=timeout_seconds) output_dict = self._get_output_dict(output_directory) outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict) @@ -330,8 +338,12 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None + pod_spec = _serialize_pod_spec(self.pod_template, self._get_container(settings), settings) + if self._timeout is not None: + timeout_seconds = int(self._timeout.total_seconds()) + pod_spec["activeDeadlineSeconds"] = timeout_seconds return _task_model.K8sPod( - pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings), settings), + pod_spec=pod_spec, metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 1281a9ec14..ee189eea96 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -1,7 +1,11 @@ import os import sys +import time +import docker + from collections import OrderedDict from typing import Tuple +from datetime import timedelta import pytest from kubernetes.client.models import ( @@ -238,3 +242,114 @@ def test_container_task_image_spec(mock_image_spec_builder): pod = ct.get_k8s_pod(default_serialization_settings) assert pod.pod_spec["containers"][0]["image"] == image_spec_1.image_name() assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name() + +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) +def test_container_task_timeout(): + ct_with_timedelta = ContainerTask( + name="timedelta-timeout-test", + image="busybox", + command=["sleep", "100"], + timeout=timedelta(seconds=1), + ) + + with pytest.raises((docker.errors.APIError, Exception)): + ct_with_timedelta.execute() + +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) +def test_container_task_timeout_k8s_serialization(): + + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ct_timedelta = ContainerTask( + name="timeout-k8s-timedelta-test", + image="busybox", + command=["echo", "hello"], + pod_template=pt, + timeout=timedelta(minutes=2), + ) + + k8s_pod_timedelta = ct_timedelta.get_k8s_pod(default_serialization_settings) + assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120 + + +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) +def test_container_task_timeout_in_metadata(): + from flytekit.core.base_task import TaskMetadata + + ct_with_timedelta = ContainerTask( + name="timeout-metadata-test", + image="busybox", + command=["echo", "hello"], + timeout=timedelta(minutes=5), + ) + + assert ct_with_timedelta.metadata.timeout == timedelta(minutes=5) + + # Test with custom metadata - timeout should be set in the provided metadata + custom_metadata = TaskMetadata(retries=3) + ct_with_custom_metadata = ContainerTask( + name="custom-metadata-timeout-test", + image="busybox", + command=["echo", "hello"], + metadata=custom_metadata, + timeout=timedelta(seconds=30), + ) + + # Verify timeout is set in the custom metadata and retries are preserved + assert ct_with_custom_metadata.metadata.timeout == timedelta(seconds=30) + assert ct_with_custom_metadata.metadata.retries == 3 + + ct_without_timeout = ContainerTask( + name="no-timeout-test", + image="busybox", + command=["echo", "hello"] + ) + + assert ct_without_timeout.metadata.timeout is None + + +def test_container_task_timeout_serialization(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"}) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ct_with_timeout = ContainerTask( + name="timeout-serialization-test", + image="busybox", + command=["echo", "hello"], + pod_template=pt, + timeout=timedelta(minutes=10), + ) + + from flytekit.tools.translator import get_serializable_task + from collections import OrderedDict + + serialized_task = get_serializable_task(OrderedDict(), default_serialization_settings, ct_with_timeout) + + k8s_pod = ct_with_timeout.get_k8s_pod(default_serialization_settings) + assert k8s_pod.pod_spec["activeDeadlineSeconds"] == 600 # 10 minutes in seconds