diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 00a9865a..0acc1e12 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -84,7 +84,7 @@ def __init__( self.cache_dir_prefix = Path(cache_dir_prefix) if not self.cache_dir_prefix.exists(): self.cache_dir_prefix.mkdir(parents=True, exist_ok=True) - self.cache_dir = Path(tempfile.mkdtemp(dir=cache_dir_prefix)) + self.cache_dir = Path(tempfile.mkdtemp(dir=self.cache_dir_prefix)) self.save_fns: list[str] = [] self.load_fns: list[str] = [] @@ -217,8 +217,14 @@ def __eq__(self, other: Any) -> bool: def cleanup_cache_dir(self) -> None: """Clean up the cache directory.""" - if self.cache_dir.exists(): - shutil.rmtree(self.cache_dir) + if hasattr(self, "cache_dir") and self.cache_dir is not None: + cache_path = Path(self.cache_dir) + + if not isinstance(cache_path, Path): + raise TypeError(f"cache_dir must be path-like, got {type(self.cache_dir)}") + + if cache_path.exists() and cache_path.is_dir(): + shutil.rmtree(cache_path, ignore_errors=True) def reset_cache_dir(self) -> None: """Reset the cache directory."""