diff --git a/src/dflow/argo_objects.py b/src/dflow/argo_objects.py
index 85c99668..bc9b4601 100644
--- a/src/dflow/argo_objects.py
+++ b/src/dflow/argo_objects.py
@@ -3,7 +3,6 @@
 import logging
 import os
 import shutil
-import tempfile
 import time
 from collections import UserDict, UserList
 from copy import deepcopy
@@ -13,7 +12,7 @@
 from .config import config, s3_config
 from .io import S3Artifact
 from .op_template import get_k8s_client
-from .utils import download_artifact, get_key, upload_s3
+from .utils import TempDir, download_artifact, get_key, upload_s3
 
 try:
     import kubernetes
@@ -82,7 +81,7 @@ def __getattr__(self, key):
         if ((key == "value" and "value" not in self.data) or
             (key == "type" and "type" not in self.data)) and \
                 hasattr(self, "save_as_artifact"):
-            with tempfile.TemporaryDirectory() as tmpdir:
+            with TempDir() as tmpdir:
                 try:
                     download_artifact(self, path=tmpdir)
                     fs = os.listdir(tmpdir)
@@ -163,7 +162,7 @@ def modify_output_parameter(
             self.outputs.parameters[name].value = jsonpickle.dumps(value)
 
         if hasattr(self.outputs.parameters[name], "save_as_artifact"):
-            with tempfile.TemporaryDirectory() as tmpdir:
+            with TempDir() as tmpdir:
                 path = tmpdir + "/" + name
                 with open(path, "w") as f:
                     f.write(jsonpickle.dumps(value))
diff --git a/src/dflow/io.py b/src/dflow/io.py
index 8eae61b8..57bdf079 100644
--- a/src/dflow/io.py
+++ b/src/dflow/io.py
@@ -1,5 +1,4 @@
 import json
-import tempfile
 from collections import UserDict
 from copy import copy, deepcopy
 from typing import Any, Dict, List, Optional, Union
@@ -7,7 +6,7 @@
 from .common import (CustomArtifact, HTTPArtifact, LocalArtifact, S3Artifact,
                      jsonpickle, param_errmsg, param_regex)
 from .config import config
-from .utils import randstr, s3_config, upload_s3
+from .utils import TempDir, randstr, s3_config, upload_s3
 
 try:
     from argo.workflows.client import (V1alpha1ArchiveStrategy,
@@ -517,7 +516,7 @@ def convert_to_argo(self):
                 path=self.path, _from=str(self.value))
         else:
-            with tempfile.TemporaryDirectory() as tmpdir:
+            with TempDir() as tmpdir:
                 path = tmpdir + "/" + self.name
                 with open(path, "w") as f:
                     f.write(jsonpickle.dumps(self.value))
diff --git a/src/dflow/utils.py b/src/dflow/utils.py
index 519aade3..2ec27b3f 100644
--- a/src/dflow/utils.py
+++ b/src/dflow/utils.py
@@ -31,6 +31,14 @@
     pass
 
 
+class TempDir(tempfile.TemporaryDirectory):
+    def cleanup(self):
+        try:
+            return super().cleanup()
+        except Exception:
+            pass
+
+
 def get_key(artifact, raise_error=True):
     if hasattr(artifact, "s3") and hasattr(artifact.s3, "key"):
         return artifact.s3.key
@@ -123,7 +131,7 @@ def download_artifact(
         if key[-4:] == ".tgz" and extract:
             path = os.path.join(path, os.path.basename(key))
             tf = tarfile.open(path, "r:gz")
-            with tempfile.TemporaryDirectory() as tmpdir:
+            with TempDir() as tmpdir:
                 tf.extractall(tmpdir)
                 tf.close()
 
@@ -219,7 +227,7 @@ def upload_artifact(
     if archive == "default":
         archive = config["archive_mode"]
     cwd = os.getcwd()
-    with tempfile.TemporaryDirectory() as tmpdir:
+    with TempDir() as tmpdir:
         if isinstance(path, dict) or (isinstance(path, list) and any(
                 [isinstance(p, (list, dict)) for p in path])):
             pairs = flatten(path).items()
@@ -320,7 +328,7 @@ def copy_artifact(src, dst, sort=False, **kwargs) -> S3Artifact:
                 key=lambda item: item["order"])["order"] + 1
             for item in src_catalog:
                 item["order"] += offset
-    with tempfile.TemporaryDirectory() as tmpdir:
+    with TempDir() as tmpdir:
         catalog_dir = os.path.join(tmpdir, config["catalog_dir_name"])
         os.makedirs(catalog_dir, exist_ok=True)
         fpath = os.path.join(catalog_dir, str(uuid.uuid4()))
@@ -487,7 +495,7 @@ def catalog_of_artifact(art, storage_client=None, **kwargs) -> List[dict]:
     else:
         client = MinioClient(**kwargs)
     catalog = []
-    with tempfile.TemporaryDirectory() as tmpdir:
+    with TempDir() as tmpdir:
         objs = client.list(prefix=key)
         if len(objs) == 1 and objs[0][-1] == "/":
             key = objs[0]
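
Review note on the TempDir helper introduced in src/dflow/utils.py: it is a
drop-in replacement for tempfile.TemporaryDirectory whose cleanup() swallows
any exception, so a failed removal of the scratch directory (a file still held
open on Windows, an NFS silly-rename, a racing process) leaves the directory
behind instead of propagating out of the with block and masking the result of
the work done inside it. Below is a minimal, self-contained sketch of that
behavior; the rmtree failure is injected with unittest.mock purely for
illustration and is not part of this change:

    import shutil
    import tempfile
    from unittest import mock

    class TempDir(tempfile.TemporaryDirectory):
        def cleanup(self):
            try:
                return super().cleanup()
            except Exception:
                pass  # leave the directory behind rather than raise

    # Inject a cleanup failure: the stock TemporaryDirectory would raise
    # OSError from __exit__ here, while TempDir exits silently.
    with mock.patch.object(shutil, "rmtree", side_effect=OSError("busy")):
        with TempDir() as tmpdir:
            pass  # no exception escapes despite the failed cleanup

    shutil.rmtree(tmpdir, ignore_errors=True)  # tidy up after the demo

Python 3.10 added tempfile.TemporaryDirectory(ignore_cleanup_errors=True) with
similar intent; subclassing keeps the behavior available on interpreters that
predate it.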