From 7604fecd2076ec57fbe0bebd33fcb1efed0074f7 Mon Sep 17 00:00:00 2001 From: Benjamin Gallusser Date: Wed, 10 Dec 2025 13:39:28 -0800 Subject: [PATCH 1/2] Set pythonic default model dir --- trackastra/model/model_api.py | 2 +- trackastra/model/pretrained.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index 243bfaf..71abc51 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -172,7 +172,7 @@ def from_pretrained( Args: name: Name of pretrained model (e.g. "general_2d"). device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None). - download_dir: Directory to download model to (defaults to ~/.cache/trackastra). + download_dir: Directory to download model. Default handled by platformdirs. Returns: Trackastra model instance. diff --git a/trackastra/model/pretrained.py b/trackastra/model/pretrained.py index 767da6a..eb8f7aa 100644 --- a/trackastra/model/pretrained.py +++ b/trackastra/model/pretrained.py @@ -2,10 +2,10 @@ import shutil import tempfile import zipfile -from importlib.resources import files from pathlib import Path import requests +from platformdirs import user_data_dir from tqdm import tqdm logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def download(url: str, fname: Path): def download_pretrained(name: str, download_dir: Path | None = None): # TODO make safe, introduce versioning if download_dir is None: - download_dir = files("trackastra").joinpath(".models") + download_dir = Path(user_data_dir("trackastra")) / "models" else: download_dir = Path(download_dir) From cf3837c64437b82dd00f71fe436400529e25daed Mon Sep 17 00:00:00 2001 From: Benjamin Gallusser Date: Wed, 10 Dec 2025 14:27:27 -0800 Subject: [PATCH 2/2] Add platformdirs dep --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index be7dc3f..970d96f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ install_requires = tqdm requests psutil + platformdirs # zarr>=3 # Will need a python 3.11+ version python_requires = >=3.10 include_package_data = True