From ab21939087a04dbb9c49baa13c58892407ea945c Mon Sep 17 00:00:00 2001 From: KBolashev Date: Tue, 12 Mar 2024 18:34:02 +0200 Subject: [PATCH 01/13] Move DagsHubFilesystem to use RepoAPI to download files --- dagshub/common/api/repo.py | 48 +++++++++-- dagshub/streaming/filesystem.py | 144 +++++++++----------------------- tests/dda/mock_api.py | 4 +- 3 files changed, 80 insertions(+), 116 deletions(-) diff --git a/dagshub/common/api/repo.py b/dagshub/common/api/repo.py index 551a8fc3..45b7ac16 100644 --- a/dagshub/common/api/repo.py +++ b/dagshub/common/api/repo.py @@ -3,6 +3,8 @@ from pathlib import Path, PurePosixPath import rich.progress +from httpx import Response +from tenacity import retry_if_result, stop_after_attempt, wait_exponential, before_sleep_log, retry, retry_if_exception from dagshub.common.api.responses import ( RepoAPIResponse, @@ -54,6 +56,22 @@ class PathNotFoundError(Exception): pass +class DagsHubHTTPError(Exception): + def __init__(self, msg: str, response: Response): + super().__init__() + self.msg = msg + self.response = response + + def __str__(self): + return self.msg + + +def _is_server_error_exception(exception: BaseException) -> bool: + if not isinstance(exception, DagsHubHTTPError): + return False + return exception.response.status_code >= 500 + + class RepoAPI: def __init__(self, repo: str, host: Optional[str] = None, auth: Optional[Any] = None): """ @@ -89,7 +107,7 @@ def get_repo_info(self) -> RepoAPIResponse: error_msg = f"Got status code {res.status_code} when getting repository info." logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return dacite.from_dict(RepoAPIResponse, res.json()) def get_branch_info(self, branch: str) -> BranchAPIResponse: @@ -107,7 +125,7 @@ def get_branch_info(self, branch: str) -> BranchAPIResponse: error_msg = f"Got status code {res.status_code} when getting branch." 
logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return dacite.from_dict(BranchAPIResponse, res.json()) @@ -126,7 +144,7 @@ def get_commit_info(self, sha: str) -> CommitAPIResponse: error_msg = f"Got status code {res.status_code} when getting commit." logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return dacite.from_dict(CommitAPIResponse, res.json()["commit"]) @@ -142,7 +160,7 @@ def get_connected_storages(self) -> List[StorageAPIEntry]: error_msg = f"Got status code {res.status_code} when getting repository info." logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return [dacite.from_dict(StorageAPIEntry, storage_entry) for storage_entry in res.json()] @@ -164,7 +182,7 @@ def list_path(self, path: str, revision: Optional[str] = None, include_size: boo error_msg = f"Got status code {res.status_code} when listing path {path}" logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) content = res.json() if type(content) is dict: @@ -194,7 +212,7 @@ def _get(): error_msg = f"Got status code {res.status_code} when listing path {path}" logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) content = res.json() if "entries" not in content: @@ -218,6 +236,12 @@ def _get(): return entries + @retry( + retry=retry_if_result(_is_server_error_exception), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) def get_file(self, path: str, revision: Optional[str] = None) -> bytes: """ Download file from repo. 
@@ -229,16 +253,22 @@ def get_file(self, path: str, revision: Optional[str] = None) -> bytes: Returns: bytes: The content of the file. """ - res = self._http_request("GET", self.raw_api_url(path, revision)) + res = self._http_request("GET", self.raw_api_url(path, revision), timeout=None) if res.status_code == 404: raise PathNotFoundError(f"Path {path} not found") elif res.status_code >= 400: error_msg = f"Got status code {res.status_code} when getting file {path}" logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return res.content + @retry( + retry=retry_if_result(_is_server_error_exception), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) def get_storage_file(self, path: str) -> bytes: """ Download file from a connected storage bucket. @@ -258,7 +288,7 @@ def get_storage_file(self, path: str) -> bytes: error_msg = f"Got status code {res.status_code} when getting file {path}" logger.error(error_msg) logger.debug(res.content) - raise RuntimeError(error_msg) + raise DagsHubHTTPError(error_msg, res) return res.content def _get_files_in_path( diff --git a/dagshub/streaming/filesystem.py b/dagshub/streaming/filesystem.py index b5f40694..2e1bcf06 100644 --- a/dagshub/streaming/filesystem.py +++ b/dagshub/streaming/filesystem.py @@ -11,17 +11,15 @@ from multiprocessing import AuthenticationError from os import PathLike from pathlib import Path, PurePosixPath -from typing import Optional, TypeVar, Union, Dict, Set, Tuple, List, Any, Callable +from typing import Optional, TypeVar, Union, Dict, Set, Tuple, List, Callable from urllib.parse import urlparse, ParseResult -import dacite -from httpx import Response -from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential, before_sleep_log, RetryError +from tenacity import RetryError from dagshub.common import config, is_inside_notebook, 
is_inside_colab -from dagshub.common.api.repo import RepoAPI, CommitNotFoundError -from dagshub.common.api.responses import ContentAPIEntry, StorageContentAPIResult -from dagshub.common.helpers import http_request, get_project_root +from dagshub.common.api.repo import RepoAPI, CommitNotFoundError, PathNotFoundError, DagsHubHTTPError +from dagshub.common.api.responses import ContentAPIEntry +from dagshub.common.helpers import get_project_root from dagshub.streaming.dataclasses import DagshubPath from dagshub.streaming.errors import FilesystemAlreadyMountedError @@ -71,9 +69,6 @@ def __exit__(self, *args): SPECIAL_FILE = Path(".dagshub-streaming") -def _is_server_error(resp: Response): - return resp.status_code >= 500 - # TODO: Singleton metaclass that lets us keep a "main" DvcFilesystem instance class DagsHubFilesystem: @@ -257,30 +252,6 @@ def is_subpath(a: Path, b: Path) -> bool: DagsHubFilesystem.already_mounted_filesystems[self.project_root] = self - def get_remote_branch_head(self, branch): - """ - Get the head commit ID of a remote branch. - - Args: - branch (str): The name of the remote branch. - - Raises: - RuntimeError: Raised if there is an issue when trying to get the head of the branch. - - Returns: - str: The commit ID of the head of the remote branch. - - :meta private: - """ - url = self.get_api_url(f"/api/v1/repos{self.parsed_repo_url.path}/branches/{branch}") - resp = self.http_get(url) - if resp.status_code != 200: - raise RuntimeError( - f"Got status {resp.status_code} while trying to get head of branch {branch}. 
\r\n" - f"Response body: {resp.content}" - ) - return resp.json()["commit"]["id"] - @property def auth(self): import dagshub.auth @@ -392,22 +363,16 @@ def open(self, file, mode="r", buffering=-1, encoding=None, errors=None, newline # Open for reading - try to download the file if "r" in mode: try: - resp = self._api_download_file_git(path) + contents = self._api_download_file_git(path) except RetryError: raise RuntimeError(f"Couldn't download {path.relative_path} after multiple attempts") - if resp.status_code < 400: - self._mkdirs(path.absolute_path.parent) - # TODO: Handle symlinks - with self.__open(path.absolute_path, "wb") as output: - output.write(resp.content) - return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) - elif resp.status_code == 404: + except PathNotFoundError: raise FileNotFoundError(f"Error finding {path.relative_path} in repo or on DagsHub") - else: - raise RuntimeError( - f"Got response code {resp.status_code} from DagsHub while downloading file" - f" {path.relative_path}" - ) + self._mkdirs(path.absolute_path.parent) + # TODO: Handle symlinks + with self.__open(path.absolute_path, "wb") as output: + output.write(contents) + return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) # Write modes - make sure that the folder is a tracked folder (create if doesn't exist on disk), # and then let the user write to file else: @@ -419,12 +384,13 @@ def open(self, file, mode="r", buffering=-1, encoding=None, errors=None, newline # Try to download the file if we're in append modes if "a" in mode or "+" in mode: try: - resp = self._api_download_file_git(path) + contents = self._api_download_file_git(path) except RetryError: raise RuntimeError(f"Couldn't download {path.relative_path} after multiple attempts") - if resp.status_code < 400: - with self.__open(path.absolute_path, "wb") as output: - output.write(resp.content) + except PathNotFoundError: + raise 
FileNotFoundError(f"Error finding {path.relative_path} in repo or on DagsHub") + with self.__open(path.absolute_path, "wb") as output: + output.write(contents) return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) else: @@ -705,48 +671,24 @@ def generate_entry(path, is_directory): return res def _api_listdir(self, path: DagshubPath, include_size: bool = False) -> Optional[List[ContentAPIEntry]]: - response, hit = self._check_listdir_cache(path.relative_path.as_posix(), include_size) + assert path.relative_path is not None + + repo_path = path.relative_path.as_posix() + response, hit = self._check_listdir_cache(repo_path, include_size) if hit: return response - params: Dict[str, Any] = {"include_size": "true"} if include_size else {} - if path.is_storage_path: - params["paging"] = True - url = self._content_url_for_path(path) - - def _get() -> Optional[Response]: - resp = self.http_get(url, params=params, headers=config.requests_headers) - if resp.status_code == 404: - logger.debug(f"Got HTTP code {resp.status_code} while listing {path}, no results will be returned") - return None - elif resp.status_code >= 400: - logger.warning(f"Got HTTP code {resp.status_code} while listing {path}, no results will be returned") - return None - return resp - - response = _get() - if response is None: + + res: List[ContentAPIEntry] + try: + if path.is_storage_path: + storage_path = repo_path[len(".dagshub/storage/"):] + res = self._api.list_storage_path(storage_path, include_size=include_size) + else: + res = self._api.list_path(repo_path, self._current_revision, include_size=include_size) + except (PathNotFoundError, DagsHubHTTPError): return None - res: List[ContentAPIEntry] = [] - # Storage - token pagination, different return structure + if there's a token we do another request - if path.is_storage_path: - result = dacite.from_dict(StorageContentAPIResult, response.json()) - res += result.entries - while result.next_token is not None: - 
params["from_token"] = result.next_token - new_resp = _get() - if new_resp is None: - return None - result = dacite.from_dict(StorageContentAPIResult, new_resp.json()) - res += result.entries - else: - for entry_raw in response.json(): - entry = dacite.from_dict(ContentAPIEntry, entry_raw) - # Ignore storage root entries, we handle them separately in a different place - if entry.type == "storage": - continue - res.append(entry) - self._listdir_cache[path.relative_path.as_posix()] = (res, include_size) + self._listdir_cache[repo_path] = (res, include_size) return res def _check_listdir_cache(self, path: str, include_size: bool) -> Tuple[Optional[List[ContentAPIEntry]], bool]: @@ -776,22 +718,14 @@ def _raw_url_for_path(self, path: DagshubPath): return self._api.storage_raw_api_url(path_to_access) return self._api.raw_api_url(str_path, self._current_revision) - @retry( - retry=retry_if_result(_is_server_error), - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - def _api_download_file_git(self, path: DagshubPath): - resp = self.http_get(self._raw_url_for_path(path), headers=config.requests_headers, timeout=None) - return resp - - def http_get(self, path: str, **kwargs): - timeout = self.timeout - if "timeout" in kwargs: - timeout = kwargs["timeout"] - del kwargs["timeout"] - return http_request("GET", path, auth=self.auth, timeout=timeout, **kwargs) + def _api_download_file_git(self, path: DagshubPath) -> bytes: + if path.relative_path is None: + raise RuntimeError(f"Can't access path {path.absolute_path} outside of repo") + str_path = path.relative_path.as_posix() + if path.is_storage_path: + str_path = str_path[len(".dagshub/storage/"):] + return self._api.get_storage_file(str_path) + return self._api.get_file(str_path, self._current_revision) def install_hooks(self): """ diff --git a/tests/dda/mock_api.py b/tests/dda/mock_api.py index 336f6138..6a1667de 100644 --- 
a/tests/dda/mock_api.py +++ b/tests/dda/mock_api.py @@ -62,7 +62,7 @@ def _default_endpoints_and_responses(self): "repo": rf"{self.repoapipath}/?$", "branch": rf"{self.repoapipath}/branches/(main|master)$", "branches": rf"{self.repoapipath}/branches/?$", - "list_root": rf"{self.repoapipath}/content/{self.current_revision}/$", + "list_root": rf"{self.repoapipath}/content/{self.current_revision}/(\?include_size=(false|true))?$", "storages": rf"{self.repoapipath}/storage/?$", } @@ -250,7 +250,7 @@ def add_storage_dir(self, path, contents=[], from_token=None, next_token=None, s Add a directory to the storage api Storage has a different response schema """ - url = f"{self.api_storage_list_path}/{path}?paging=true" + url = f"{self.api_storage_list_path}/{path}?include_size=false&paging=true" if from_token is not None: url += f"&from_token={from_token}" route = self.route(url=url) From 7079a6182fe91e14a872bd49be7926888434ea07 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Wed, 13 Mar 2024 16:19:35 +0200 Subject: [PATCH 02/13] Add Ruff Do a bunch of fixes in the filesystem.py related to MyPy and Ruff --- dagshub/streaming/dataclasses.py | 2 +- dagshub/streaming/filesystem.py | 101 +++++++++++++----------- dagshub/streaming/hook_router.py | 12 +++ ruff.toml | 5 ++ tests/dda/filesystem/test_multihooks.py | 80 +++++++++++++++++++ tests/mocks/repo_api.py | 13 ++- 6 files changed, 165 insertions(+), 48 deletions(-) create mode 100644 dagshub/streaming/hook_router.py create mode 100644 ruff.toml create mode 100644 tests/dda/filesystem/test_multihooks.py diff --git a/dagshub/streaming/dataclasses.py b/dagshub/streaming/dataclasses.py index f205afbe..4c576151 100644 --- a/dagshub/streaming/dataclasses.py +++ b/dagshub/streaming/dataclasses.py @@ -8,7 +8,7 @@ from cached_property import cached_property if TYPE_CHECKING: - from dagshub.streaming import DagsHubFilesystem + from filesystem import DagsHubFilesystem storage_schemas = ["s3", "gs", "azure"] diff --git 
a/dagshub/streaming/filesystem.py b/dagshub/streaming/filesystem.py index 2e1bcf06..26ea5d52 100644 --- a/dagshub/streaming/filesystem.py +++ b/dagshub/streaming/filesystem.py @@ -27,7 +27,7 @@ # In 3.11 _NormalAccessor was removed PRE_PYTHON3_11 = sys.version_info.major == 3 and sys.version_info.minor < 11 if PRE_PYTHON3_11: - from pathlib import _NormalAccessor as _pathlib # noqa: E402 + from pathlib import _NormalAccessor as _pathlib # noqa try: from functools import cached_property @@ -49,7 +49,7 @@ def wrapper(*args, **kwargs): return decorator -class dagshub_ScandirIterator: +class DagshubScandirIterator: def __init__(self, iterator): self._iterator = iterator @@ -69,7 +69,6 @@ def __exit__(self, *args): SPECIAL_FILE = Path(".dagshub-streaming") - # TODO: Singleton metaclass that lets us keep a "main" DvcFilesystem instance class DagsHubFilesystem: """ @@ -184,7 +183,7 @@ def _current_revision(self) -> str: """ Gets current revision on repo: - If User specified a branch, returns HEAD of that brunch on the remote - - If branch wasn't detected, returns HEAD of default branch in the speficied remote. + - If branch wasn't detected, returns HEAD of default branch in the specified remote. 
- If HEAD is a branch, tries to find a dagshub remote associated with it and get its HEAD - If HEAD is a commit revision, checks that the commit exists on DagsHub """ @@ -287,6 +286,7 @@ def get_remotes_from_git_config(self) -> List[str]: for remote in git_remotes: if remote.hostname != config.hostname: continue + assert remote.hostname is not None remote = remote._replace(netloc=remote.hostname) remote = remote._replace(path=re.compile(r"(\.git)?/?$").sub("", remote.path)) res_remotes.append(remote.geturl()) @@ -300,10 +300,12 @@ def cleanup(self): if hasattr(self, "project_root") and self.project_root in DagsHubFilesystem.already_mounted_filesystems: DagsHubFilesystem.already_mounted_filesystems.pop(self.project_root) - def _parse_path(self, file: Union[str, PathLike, int]) -> DagshubPath: + def _parse_path(self, file: Union[str, bytes, PathLike, DagshubPath]) -> DagshubPath: + if isinstance(file, DagshubPath): + return file + if isinstance(file, bytes): + file = os.fsdecode(file) orig_path = Path(file) - if isinstance(file, int): - return DagshubPath(self, None, None, orig_path) if file == "": return DagshubPath(self, None, None, orig_path) abspath = Path(os.path.abspath(file)) @@ -315,11 +317,22 @@ def _parse_path(self, file: Union[str, PathLike, int]) -> DagshubPath: except ValueError: return DagshubPath(self, abspath, None, orig_path) - def _special_file(self): + @staticmethod + def _special_file(): # TODO Include more information in this file return b"v0\n" - def open(self, file, mode="r", buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None): + def open( + self, + file: Union[str, int, bytes, PathLike, DagshubPath], + mode="r", + buffering=-1, + encoding=None, + errors=None, + newline=None, + closefd=True, + opener=None, + ): """ NOTE: This is a wrapper function for python's built-in file operations (https://docs.python.org/3/library/functions.html#open) @@ -343,13 +356,15 @@ def open(self, file, mode="r", buffering=-1, 
encoding=None, errors=None, newline :meta private: """ # FD passthrough - if type(file) is int: + if isinstance(file, int): return self.__open(file, mode, buffering, encoding, errors, newline, closefd) - if type(file) is bytes: + if isinstance(file, bytes): file = os.fsdecode(file) path = self._parse_path(file) if path.is_in_repo: + assert path.relative_path is not None + assert path.absolute_path is not None if opener is not None: raise NotImplementedError("DagsHub's patched open() does not support custom openers") if path.is_passthrough_path: @@ -373,7 +388,7 @@ def open(self, file, mode="r", buffering=-1, encoding=None, errors=None, newline with self.__open(path.absolute_path, "wb") as output: output.write(contents) return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) - # Write modes - make sure that the folder is a tracked folder (create if doesn't exist on disk), + # Write modes - make sure that the folder is a tracked folder (create if it doesn't exist on disk), # and then let the user write to file else: try: @@ -396,7 +411,7 @@ def open(self, file, mode="r", buffering=-1, encoding=None, errors=None, newline else: return self.__open(file, mode, buffering, encoding, errors, newline, closefd, opener) - def os_open(self, path, flags, mode=0o777, *, dir_fd=None): + def os_open(self, path: Union[str, bytes, PathLike, DagshubPath], flags, mode=0o777, *, dir_fd=None): """ os.open is supposed to be lower level, but it's still being used by e.g. 
Pathlib We're trying to wrap around it here, by parsing flags and calling the higher-level open @@ -411,23 +426,25 @@ def os_open(self, path, flags, mode=0o777, *, dir_fd=None): if dir_fd is not None: # If dir_fd supplied, path is relative to that dir's fd, will handle in the future logger.debug("fs.os_open - NotImplemented") raise NotImplementedError("DagsHub's patched os.open() (for pathlib only) does not support dir_fd") - path = self._parse_path(path) - if path.is_in_repo: + dh_path = self._parse_path(path) + if dh_path.is_in_repo: + assert dh_path.absolute_path is not None + assert dh_path.relative_path is not None try: open_mode = "r" # Write modes - calling in append mode, # This way we create the intermediate folders if file doesn't exist, but the folder it's in does - # Append so we don't truncate the file + # Append, so we don't truncate the file if not (flags & os.O_RDONLY): open_mode = "a" logger.debug("fs.os_open - trying to materialize path") - self.open(path.absolute_path, mode=open_mode).close() + self.open(dh_path.absolute_path, mode=open_mode).close() logger.debug("fs.os_open - successfully materialized path") except FileNotFoundError: logger.debug("fs.os_open - failed to materialize path, os.open will throw") - return os.open(path.absolute_path, flags, mode, dir_fd=dir_fd) + return os.open(dh_path.absolute_path, flags, mode, dir_fd=dir_fd) - def stat(self, path, *args, dir_fd=None, follow_symlinks=True): + def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follow_symlinks=True): """ NOTE: This is a wrapper function for python's built-in file operations (https://docs.python.org/3/library/os.html#os.stat) @@ -435,7 +452,7 @@ def stat(self, path, *args, dir_fd=None, follow_symlinks=True): Get the status of a file or directory, including support for special files and DagsHub integration. Args: - path (Union[str, int, bytes]): The path of the file or directory to get the status for. 
+ path: The path of the file or directory to get the status for. It can be a path (str), file descriptor (int), or bytes-like object. dir_fd (int, optional): File descriptor of the directory. Defaults to None. follow_symlinks (bool, optional): Whether to follow symbolic links. Defaults to True. @@ -446,10 +463,10 @@ def stat(self, path, *args, dir_fd=None, follow_symlinks=True): :meta private: """ # FD passthrough - if type(path) is int: + if isinstance(path, int): return self.__stat(path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) - if type(path) is bytes: + if isinstance(path, bytes): path = os.fsdecode(path) if dir_fd is not None or not follow_symlinks: logger.debug("fs.stat - NotImplemented") @@ -457,11 +474,13 @@ def stat(self, path, *args, dir_fd=None, follow_symlinks=True): parsed_path = self._parse_path(path) # todo: remove False if parsed_path.is_in_repo: + assert parsed_path.relative_path is not None + assert parsed_path.absolute_path is not None logger.debug("fs.stat - is relative path") if parsed_path.is_passthrough_path: return self.__stat(parsed_path.absolute_path) elif parsed_path.relative_path == SPECIAL_FILE: - return dagshub_stat_result(self, path, is_directory=False, custom_size=len(self._special_file())) + return DagshubStatResult(self, parsed_path, is_directory=False, custom_size=len(self._special_file())) else: try: logger.debug(f"fs.stat - calling __stat - relative_path: {path}") @@ -488,7 +507,7 @@ def stat(self, path, *args, dir_fd=None, follow_symlinks=True): raise err if filetype == "file": - return dagshub_stat_result(self, path, is_directory=False) + return DagshubStatResult(self, parsed_path, is_directory=False) elif filetype == "dir": self._mkdirs(parsed_path.absolute_path) return self.__stat(parsed_path.absolute_path) @@ -512,10 +531,10 @@ def chdir(self, path): :meta private: """ # FD check - if type(path) is int: + if isinstance(path, int): return self.__chdir(path) - if type(path) is bytes: + if isinstance(path, 
bytes): path = os.fsdecode(path) parsed_path = self._parse_path(path) if parsed_path.is_in_repo: @@ -550,11 +569,11 @@ def listdir(self, path="."): :meta private: """ # FD check - if type(path) is int: + if isinstance(path, int): return self.__listdir(path) # listdir needs to return results for bytes path arg also in bytes - is_bytes_path_arg = type(path) is bytes + is_bytes_path_arg = isinstance(path, bytes) def encode_results(res): res = list(res) @@ -605,15 +624,15 @@ def encode_results(res): def project_root_dagshub_path(self): return DagshubPath(absolute_path=self.project_root, relative_path=Path(), original_path=Path(), fs=self) - @wrapreturn(dagshub_ScandirIterator) + @wrapreturn(DagshubScandirIterator) def scandir(self, path="."): # FD check - if type(path) is int: + if isinstance(path, int): for direntry in self.__scandir(path): yield direntry return # scandir needs to return name and path as bytes, if entry arg is bytes - is_bytes_path_arg = type(path) is bytes + is_bytes_path_arg = isinstance(path, bytes) if is_bytes_path_arg: str_path = os.fsdecode(path) else: @@ -639,18 +658,18 @@ def scandir(self, path="."): for f in resp: name = PurePosixPath(f.path).name if name not in local_filenames: - yield dagshub_DirEntry(self, parsed_path / name, f.type == "dir", is_binary=is_bytes_path_arg) + yield DagshubDirEntry(self, parsed_path / name, f.type == "dir", is_binary=is_bytes_path_arg) else: for entry in self.__scandir(path): yield entry def _get_special_paths( self, dh_path: DagshubPath, relative_to: DagshubPath, is_binary: bool - ) -> Set["dagshub_DirEntry"]: + ) -> Set["DagshubDirEntry"]: def generate_entry(path, is_directory): if isinstance(path, str): path = Path(path) - return dagshub_DirEntry(self, relative_to / path, is_directory=is_directory, is_binary=is_binary) + return DagshubDirEntry(self, relative_to / path, is_directory=is_directory, is_binary=is_binary) has_storages = len(self._storages) > 0 res = set() @@ -681,7 +700,7 @@ def 
_api_listdir(self, path: DagshubPath, include_size: bool = False) -> Optiona res: List[ContentAPIEntry] try: if path.is_storage_path: - storage_path = repo_path[len(".dagshub/storage/"):] + storage_path = repo_path[len(".dagshub/storage/") :] res = self._api.list_storage_path(storage_path, include_size=include_size) else: res = self._api.list_path(repo_path, self._current_revision, include_size=include_size) @@ -723,7 +742,7 @@ def _api_download_file_git(self, path: DagshubPath) -> bytes: raise RuntimeError(f"Can't access path {path.absolute_path} outside of repo") str_path = path.relative_path.as_posix() if path.is_storage_path: - str_path = str_path[len(".dagshub/storage/"):] + str_path = str_path[len(".dagshub/storage/") :] return self._api.get_storage_file(str_path) return self._api.get_file(str_path, self._current_revision) @@ -996,7 +1015,7 @@ def uninstall_hooks(): DagsHubFilesystem.uninstall_hooks() -class dagshub_stat_result: +class DagshubStatResult: def __init__(self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool, custom_size: int = None): self._fs = fs self._path = path @@ -1030,7 +1049,7 @@ def __repr__(self): return f"dagshub_stat_result({inner}, path={self._path})" -class dagshub_DirEntry: +class DagshubDirEntry: def __init__(self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool = False, is_binary: bool = False): self._fs = fs self._path = path @@ -1097,10 +1116,4 @@ def __repr__(self): return f"" -# Used for testing purposes only -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - fs = DagsHubFilesystem() - fs.install_hooks() - __all__ = [DagsHubFilesystem.__name__, install_hooks.__name__] diff --git a/dagshub/streaming/hook_router.py b/dagshub/streaming/hook_router.py new file mode 100644 index 00000000..8ca20b61 --- /dev/null +++ b/dagshub/streaming/hook_router.py @@ -0,0 +1,12 @@ +from os import PathLike +from typing import Union, Optional + +from dagshub.streaming import DagsHubFilesystem + + 
+class HookRouter: + def install_hooks(self, fs: DagsHubFilesystem): + pass + + def uninstall_hooks(self, fs: Optional[DagsHubFilesystem]=None, path: Optional[Union[str, PathLike]]=None): + pass diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..e414d1ae --- /dev/null +++ b/ruff.toml @@ -0,0 +1,5 @@ +line-length=120 + +[lint] +select = ["E", "F"] +ignore = ["E111", "E203", "E114", "E117", "E701"] diff --git a/tests/dda/filesystem/test_multihooks.py b/tests/dda/filesystem/test_multihooks.py new file mode 100644 index 00000000..b51177b1 --- /dev/null +++ b/tests/dda/filesystem/test_multihooks.py @@ -0,0 +1,80 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from dagshub.streaming import DagsHubFilesystem, uninstall_hooks +from tests.mocks.repo_api import MockRepoAPI + + +@pytest.fixture +def username(): + return "user" + + +@pytest.fixture +def repo_1_name(): + return "repo1" + + +@pytest.fixture +def repo_2_name(): + return "repo2" + + +@pytest.fixture +def repo_1(username, repo_1_name) -> MockRepoAPI: + repo = MockRepoAPI(f"{username}/{repo_1_name}") + repo.add_repo_file("a/b.txt", b"content repo 1") + return repo + + +@pytest.fixture +def repo_2(username, repo_2_name) -> MockRepoAPI: + repo = MockRepoAPI(f"{username}/{repo_2_name}") + repo.add_repo_file("a/b.txt", b"content repo 2") + return repo + + +def mock_repo_api_patch(repo_api: MockRepoAPI): + def mocked(_self: DagsHubFilesystem, _path): + return repo_api + + return mocked + + +def generate_mock_fs(repo_api: MockRepoAPI, file_dir: Path) -> DagsHubFilesystem: + with patch("dagshub.streaming.DagsHubFilesystem._generate_repo_api", mock_repo_api_patch(repo_api)): + fs = DagsHubFilesystem(project_root=file_dir, repo_url="https://localhost.invalid") + return fs + + +def test_mock_fs_works(repo_1, tmp_path): + fs = generate_mock_fs(repo_1, tmp_path) + assert fs.open(tmp_path / "a/b.txt", "rb").read() == b"content repo 1" + pass + + +def test_two_mock_fs(repo_1, 
repo_2, tmp_path): + path1 = tmp_path / "repo1" + path2 = tmp_path / "repo2" + fs1 = generate_mock_fs(repo_1, path1) + fs2 = generate_mock_fs(repo_2, path2) + assert fs1.open(path1 / "a/b.txt", "rb").read() == b"content repo 1" + assert fs2.open(path2 / "a/b.txt", "rb").read() == b"content repo 2" + + +def test_install_hooks_two_fs(repo_1, repo_2, tmp_path): + path1 = tmp_path / "repo1" + path2 = tmp_path / "repo2" + fs1 = generate_mock_fs(repo_1, path1) + fs2 = generate_mock_fs(repo_2, path2) + + try: + fs1.install_hooks() + fs2.install_hooks() + + assert open(path1 / "a/b.txt", "rb").read() == b"content repo 1" + assert open(path2 / "a/b.txt", "rb").read() == b"content repo 2" + finally: + uninstall_hooks() diff --git a/tests/mocks/repo_api.py b/tests/mocks/repo_api.py index af8db654..df8e7980 100644 --- a/tests/mocks/repo_api.py +++ b/tests/mocks/repo_api.py @@ -8,8 +8,7 @@ from dagshub.common.api.responses import StorageAPIEntry, ContentAPIEntry, CommitAPIResponse -class MockError(Exception): - ... +class MockError(Exception): ... 
class MockRepoAPI(RepoAPI): @@ -118,6 +117,8 @@ def get_connected_storages(self) -> List[StorageAPIEntry]: return self.storages def get_file(self, path: str, revision: Optional[str] = None) -> bytes: + if path == ".": + path = "" if revision is None: revision = self.default_branch content = self.repo_files.get(revision, {}).get(path) @@ -126,12 +127,16 @@ def get_file(self, path: str, revision: Optional[str] = None) -> bytes: return content def get_storage_file(self, path: str) -> bytes: + if path == ".": + path = "" content = self.storage_files.get(path) if content is None: raise PathNotFoundError return content def list_path(self, path: str, revision: Optional[str] = None, include_size: bool = False) -> List[ContentAPIEntry]: + if path == ".": + path = "" if revision is None: revision = self.default_branch content = self.repo_contents.get(revision, {}).get(path) @@ -140,6 +145,8 @@ def list_path(self, path: str, revision: Optional[str] = None, include_size: boo return content def list_storage_path(self, path: str, include_size: bool = False) -> List[ContentAPIEntry]: + if path == ".": + path = "" content = self.storage_contents.get(path) if content is None: raise PathNotFoundError @@ -147,7 +154,7 @@ def list_storage_path(self, path: str, include_size: bool = False) -> List[Conte def last_commit(self, branch: Optional[str] = None) -> CommitAPIResponse: return CommitAPIResponse( - id="deadbeef", + id=branch if branch is not None else "main", message="random-commit", url="http://local.invalid/commit", author=None, From d2831744485b6ffe5587d063aa5e679faa49b582 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Wed, 13 Mar 2024 18:13:51 +0200 Subject: [PATCH 03/13] Extract parse_path to DagshubPath --- dagshub/streaming/dataclasses.py | 50 +++++++++++++++++------- dagshub/streaming/filesystem.py | 64 +++++++++++++------------------ tests/dda/filesystem/test_misc.py | 7 ++-- 3 files changed, 66 insertions(+), 55 deletions(-) diff --git a/dagshub/streaming/dataclasses.py 
b/dagshub/streaming/dataclasses.py index 4c576151..3076bf7d 100644 --- a/dagshub/streaming/dataclasses.py +++ b/dagshub/streaming/dataclasses.py @@ -1,6 +1,8 @@ +import os from dataclasses import dataclass +from os import PathLike from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Union, Tuple try: from functools import cached_property @@ -26,13 +28,30 @@ class DagshubPath: original_path (Path): Original path as it was accessed by the user """ - # TODO: this couples this class hard to the fs, need to decouple later - fs: "DagsHubFilesystem" # Actual type is DagsHubFilesystem, but imports are wonky - absolute_path: Optional[Path] - relative_path: Optional[Path] - original_path: Optional[Path] + def __init__(self, fs: "DagsHubFilesystem", file_path: Union[str, bytes, PathLike, "DagshubPath"]): + self.fs = fs + self.absolute_path, self.relative_path, self.original_path = self.parse_path(file_path) - def __post_init__(self): + def parse_path(self, file_path: Union[str, bytes, PathLike, "DagshubPath"]) -> Tuple[Path, Optional[Path], Path]: + print(self.fs.project_root) + if isinstance(file_path, DagshubPath): + if file_path.fs != self.fs: + relativized = DagshubPath(self.fs, file_path.absolute_path) + return relativized.absolute_path, relativized.relative_path, relativized.original_path + return file_path.absolute_path, file_path.relative_path, file_path.original_path + if isinstance(file_path, bytes): + file_path = os.fsdecode(file_path) + orig_path = Path(file_path) + abspath = Path(os.path.abspath(file_path)) + try: + relpath = abspath.relative_to(os.path.abspath(self.fs.project_root)) + if str(relpath).startswith("<"): + return abspath, None, orig_path + return abspath, relpath, orig_path + except ValueError: + return abspath, None, orig_path + + def handle_storages(self): # Handle storage paths - translate s3:/bla-bla to .dagshub/storage/s3/bla-bla if self.relative_path is not None: str_path = 
self.relative_path.as_posix() @@ -41,9 +60,11 @@ def __post_init__(self): str_path = str_path[len(storage_schema) + 2 :] self.relative_path = Path(".dagshub/storage") / storage_schema / str_path self.absolute_path = self.fs.project_root / self.relative_path + break @cached_property def name(self): + assert self.absolute_path is not None return self.absolute_path.name @cached_property @@ -56,10 +77,11 @@ def is_storage_path(self): Is path a storage path (stored in a bucket) Those paths are accessible via a path like `.dagshub/storage/s3/bucket/...` """ + if self.relative_path is None: + return False return self.relative_path.as_posix().startswith(".dagshub/storage") - @cached_property - def is_passthrough_path(self): + def is_passthrough_path(self, fs: "DagsHubFilesystem"): """ Is path a "passthrough" path A passthrough path is a path that the FS ignores when trying to look up if the file exists on DagsHub @@ -68,17 +90,17 @@ def is_passthrough_path(self): If you need to read with streaming from a .dvc folder (to read config for example), please pull the repo - Any /site-packages/ folder - if you have a venv in your repo, python will try to find packages there. 
""" + if self.relative_path is None: + return True str_path = self.relative_path.as_posix() if "/site-packages/" in str_path or str_path.endswith("/site-packages"): return True if str_path.startswith((".git/", ".dvc/")) or str_path in (".git", ".dvc"): return True - return any((self.relative_path.match(glob) for glob in self.fs.exclude_globs)) + return any((self.relative_path.match(glob) for glob in fs.exclude_globs)) def __truediv__(self, other): return DagshubPath( - absolute_path=self.absolute_path / other, - relative_path=self.relative_path / other, - original_path=self.original_path / other, - fs=self.fs, + self.fs, + self.original_path / other, ) diff --git a/dagshub/streaming/filesystem.py b/dagshub/streaming/filesystem.py index 26ea5d52..4ac9dde0 100644 --- a/dagshub/streaming/filesystem.py +++ b/dagshub/streaming/filesystem.py @@ -152,21 +152,22 @@ def __init__( self.token = token or config.token self.timeout = timeout or config.http_timeout + self.exclude_globs: List[str] if exclude_globs is None: - exclude_globs = [] - elif exclude_globs is str: - exclude_globs = [exclude_globs] - - self.exclude_globs: List[str] = exclude_globs + self.exclude_globs = [] + elif isinstance(exclude_globs, str): + self.exclude_globs = [exclude_globs] + else: + self.exclude_globs = exclude_globs - self._listdir_cache: Dict[str, Optional[Tuple[List[ContentAPIEntry], bool]]] = {} + self._listdir_cache: Dict[str, Tuple[Optional[List[ContentAPIEntry]], bool]] = {} self._api = self._generate_repo_api(self.parsed_repo_url) self.check_project_root_use() # Check that the repo is accessible by accessing the content root - response = self._api_listdir(DagshubPath(self, self.project_root, Path(), Path())) + response = self._api_listdir(DagshubPath(self, self.project_root)) if response is None: # TODO: Check .dvc/config{,.local} for credentials raise AuthenticationError("DagsHub credentials required, however none provided or discovered") @@ -264,7 +265,9 @@ def auth(self): 
logger.debug("Failed to perform OAuth in a non interactive shell") # Try to fetch credentials from the git credential file - proc = subprocess.run(["git", "credential", "fill"], input=f"url={self.repo_url}".encode(), capture_output=True) + proc = subprocess.run( + ["git", "credential", "fill"], input=f"url={self.parsed_repo_url.geturl()}".encode(), capture_output=True + ) answer = {line[: line.index("=")]: line[line.index("=") + 1 :] for line in proc.stdout.decode().splitlines()} if "username" in answer and "password" in answer: return answer["username"], answer["password"] @@ -300,23 +303,6 @@ def cleanup(self): if hasattr(self, "project_root") and self.project_root in DagsHubFilesystem.already_mounted_filesystems: DagsHubFilesystem.already_mounted_filesystems.pop(self.project_root) - def _parse_path(self, file: Union[str, bytes, PathLike, DagshubPath]) -> DagshubPath: - if isinstance(file, DagshubPath): - return file - if isinstance(file, bytes): - file = os.fsdecode(file) - orig_path = Path(file) - if file == "": - return DagshubPath(self, None, None, orig_path) - abspath = Path(os.path.abspath(file)) - try: - relpath = abspath.relative_to(os.path.abspath(self.project_root)) - if str(relpath).startswith("<"): - return DagshubPath(self, abspath, None, orig_path) - return DagshubPath(self, abspath, relpath, orig_path) - except ValueError: - return DagshubPath(self, abspath, None, orig_path) - @staticmethod def _special_file(): # TODO Include more information in this file @@ -361,13 +347,13 @@ def open( if isinstance(file, bytes): file = os.fsdecode(file) - path = self._parse_path(file) + path = DagshubPath(self, file) if path.is_in_repo: assert path.relative_path is not None assert path.absolute_path is not None if opener is not None: raise NotImplementedError("DagsHub's patched open() does not support custom openers") - if path.is_passthrough_path: + if path.is_passthrough_path(self): return self.__open(path.absolute_path, mode, buffering, encoding, errors, 
newline, closefd) elif path.relative_path == SPECIAL_FILE: return io.BytesIO(self._special_file()) @@ -426,9 +412,8 @@ def os_open(self, path: Union[str, bytes, PathLike, DagshubPath], flags, mode=0o if dir_fd is not None: # If dir_fd supplied, path is relative to that dir's fd, will handle in the future logger.debug("fs.os_open - NotImplemented") raise NotImplementedError("DagsHub's patched os.open() (for pathlib only) does not support dir_fd") - dh_path = self._parse_path(path) + dh_path = DagshubPath(self, path) if dh_path.is_in_repo: - assert dh_path.absolute_path is not None assert dh_path.relative_path is not None try: open_mode = "r" @@ -471,13 +456,13 @@ def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follo if dir_fd is not None or not follow_symlinks: logger.debug("fs.stat - NotImplemented") raise NotImplementedError("DagsHub's patched stat() does not support dir_fd or follow_symlinks") - parsed_path = self._parse_path(path) + parsed_path = DagshubPath(self, path) # todo: remove False if parsed_path.is_in_repo: assert parsed_path.relative_path is not None assert parsed_path.absolute_path is not None logger.debug("fs.stat - is relative path") - if parsed_path.is_passthrough_path: + if parsed_path.is_passthrough_path(self): return self.__stat(parsed_path.absolute_path) elif parsed_path.relative_path == SPECIAL_FILE: return DagshubStatResult(self, parsed_path, is_directory=False, custom_size=len(self._special_file())) @@ -536,7 +521,7 @@ def chdir(self, path): if isinstance(path, bytes): path = os.fsdecode(path) - parsed_path = self._parse_path(path) + parsed_path = DagshubPath(self, path) if parsed_path.is_in_repo: try: self.__chdir(parsed_path.absolute_path) @@ -585,9 +570,9 @@ def encode_results(res): str_path = os.fsdecode(path) else: str_path = path - parsed_path = self._parse_path(str_path) + parsed_path = DagshubPath(self, str_path) if parsed_path.is_in_repo: - if parsed_path.is_passthrough_path: + if 
parsed_path.is_passthrough_path(self): return self.listdir(parsed_path.original_path) else: dircontents: Set[str] = set() @@ -622,7 +607,7 @@ def encode_results(res): @cached_property def project_root_dagshub_path(self): - return DagshubPath(absolute_path=self.project_root, relative_path=Path(), original_path=Path(), fs=self) + return DagshubPath(self, self.project_root) @wrapreturn(DagshubScandirIterator) def scandir(self, path="."): @@ -637,8 +622,8 @@ def scandir(self, path="."): str_path = os.fsdecode(path) else: str_path = path - parsed_path = self._parse_path(str_path) - if parsed_path.is_in_repo and not parsed_path.is_passthrough_path: + parsed_path = DagshubPath(self, str_path) + if parsed_path.is_in_repo and not parsed_path.is_passthrough_path(self): path = Path(str_path) local_filenames = set() try: @@ -704,7 +689,10 @@ def _api_listdir(self, path: DagshubPath, include_size: bool = False) -> Optiona res = self._api.list_storage_path(storage_path, include_size=include_size) else: res = self._api.list_path(repo_path, self._current_revision, include_size=include_size) - except (PathNotFoundError, DagsHubHTTPError): + except PathNotFoundError: + self._listdir_cache[repo_path] = (None, True) + return None + except DagsHubHTTPError: return None self._listdir_cache[repo_path] = (res, include_size) diff --git a/tests/dda/filesystem/test_misc.py b/tests/dda/filesystem/test_misc.py index 7ba36d32..6ac938bc 100644 --- a/tests/dda/filesystem/test_misc.py +++ b/tests/dda/filesystem/test_misc.py @@ -1,10 +1,10 @@ import os.path import secrets import tempfile +from pathlib import Path from unittest.mock import MagicMock import pytest -from pathlib import Path from dagshub.streaming import DagsHubFilesystem from dagshub.streaming.dataclasses import DagshubPath @@ -27,8 +27,9 @@ ) def test_passthrough_path(path, expected): fs_mock = MagicMock() - path = DagshubPath(fs_mock, Path(os.path.abspath(path)), Path(path), Path(path)) - actual = path.is_passthrough_path + 
fs_mock.project_root = Path(os.getcwd()) + path = DagshubPath(fs_mock, path) + actual = path.is_passthrough_path(fs_mock) assert actual == expected From b34e082c63940c21ee8ea2b8019d672faafcc72a Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 14 Mar 2024 12:12:16 +0200 Subject: [PATCH 04/13] Install hooks via a router works now --- dagshub/streaming/dataclasses.py | 134 +++++- dagshub/streaming/filesystem.py | 542 ++++-------------------- dagshub/streaming/hook_router.py | 318 +++++++++++++- dagshub/streaming/util.py | 12 + requirements-dev.txt | 4 + tests/dda/filesystem/test_multihooks.py | 47 +- tests/dda/test_tokens.py | 1 - 7 files changed, 578 insertions(+), 480 deletions(-) create mode 100644 dagshub/streaming/util.py diff --git a/dagshub/streaming/dataclasses.py b/dagshub/streaming/dataclasses.py index 3076bf7d..b97a6dc5 100644 --- a/dagshub/streaming/dataclasses.py +++ b/dagshub/streaming/dataclasses.py @@ -26,15 +26,18 @@ class DagshubPath: relative_path (Optional[Path]): Path relative to the root of the encapsulating FileSystem. 
If None, path is outside the FS original_path (Path): Original path as it was accessed by the user + is_binary_path_requested (bool): For functions like scandir and listdir that have + different behaviour whether user requested a string or a binary path """ def __init__(self, fs: "DagsHubFilesystem", file_path: Union[str, bytes, PathLike, "DagshubPath"]): self.fs = fs + self.is_binary_path_requested = isinstance(file_path, bytes) self.absolute_path, self.relative_path, self.original_path = self.parse_path(file_path) def parse_path(self, file_path: Union[str, bytes, PathLike, "DagshubPath"]) -> Tuple[Path, Optional[Path], Path]: - print(self.fs.project_root) if isinstance(file_path, DagshubPath): + self.is_binary_path_requested = file_path.is_binary_path_requested if file_path.fs != self.fs: relativized = DagshubPath(self.fs, file_path.absolute_path) return relativized.absolute_path, relativized.relative_path, relativized.original_path @@ -100,7 +103,134 @@ def is_passthrough_path(self, fs: "DagsHubFilesystem"): return any((self.relative_path.match(glob) for glob in fs.exclude_globs)) def __truediv__(self, other): - return DagshubPath( + new = DagshubPath( self.fs, self.original_path / other, ) + new.is_binary_path_requested = self.is_binary_path_requested + return new + + +class DagshubScandirIterator: + def __init__(self, iterator): + self._iterator = iterator + + def __iter__(self): + return self._iterator + + def __next__(self): + return self._iterator.__next__() + + def __enter__(self): + return self + + def __exit__(self, *args): + return self + + +class DagshubStatResult: + def __init__( + self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool, custom_size: Optional[int] = None + ): + self._fs = fs + self._path = path + self._is_directory = is_directory + self._custom_size = custom_size + self._true_stat: Optional[os.stat_result] = None + assert not self._is_directory # TODO make folder stats lazy? 
+ + def __getattr__(self, name: str): + if not name.startswith("st_"): + raise AttributeError + if self._true_stat is not None: + return os.stat_result.__getattribute__(self._true_stat, name) + if name == "st_uid": + return os.getuid() + elif name == "st_gid": + return os.getgid() + elif name == "st_atime" or name == "st_mtime" or name == "st_ctime": + return 0 + elif name == "st_mode": + return 0o100644 + elif name == "st_size": + if self._custom_size is not None: + return self._custom_size + return 1100 # hardcoded size because size requests take a disproportionate amount of time + self._fs.open(self._path) + self._true_stat = self._fs.original_stat(self._path.absolute_path) + return os.stat_result.__getattribute__(self._true_stat, name) + + def __repr__(self): + inner = repr(self._true_stat) if self._true_stat is not None else "pending..." + return f"dagshub_stat_result({inner}, path={self._path})" + + +class DagshubDirEntry: + def __init__(self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool = False, is_binary: bool = False): + self._fs = fs + self._path = path + self._is_directory = is_directory + self._is_binary = is_binary + self._true_direntry: Optional[os.DirEntry] = None + + @property + def name(self): + if self._true_direntry is not None: + name = self._true_direntry.name + else: + name = self._path.name + return os.fsencode(name) if self._is_binary else name + + @property + def path(self): + if self._true_direntry is not None: + path = self._true_direntry.path + else: + path = str(self._path.original_path) + return os.fsencode(path) if self._is_binary else path + + def is_dir(self): + if self._true_direntry is not None: + return self._true_direntry.is_dir() + else: + return self._is_directory + + def is_file(self): + if self._true_direntry is not None: + return self._true_direntry.is_file() + else: + # TODO: Symlinks should return false + return not self._is_directory + + def stat(self): + if self._true_direntry is not None: + return 
self._true_direntry.stat() + else: + return self._fs.stat(self._path.original_path) + + def __getattr__(self, name: str): + if name == "_true_direntry": + raise AttributeError + if self._true_direntry is not None: + return os.DirEntry.__getattribute__(self._true_direntry, name) + + # Either create a dir, or download the file + if self._is_directory: + self._fs.mkdirs(self._path.absolute_path) + else: + self._fs.open(self._path.absolute_path) + + for direntry in self._fs.original_stat(self._path.original_path): + if direntry.name == self._path.name: + self._true_direntry = direntry + return os.DirEntry.__getattribute__(self._true_direntry, name) + else: + raise FileNotFoundError + + def __repr__(self): + cached = " (cached)" if self._true_direntry is not None else "" + return f"" + + +PathType = Union[str, int, bytes, PathLike] +PathTypeWithDagshubPath = Union[PathType, DagshubPath] diff --git a/dagshub/streaming/filesystem.py b/dagshub/streaming/filesystem.py index 4ac9dde0..7e8f32a0 100644 --- a/dagshub/streaming/filesystem.py +++ b/dagshub/streaming/filesystem.py @@ -1,33 +1,30 @@ -import builtins -import importlib import io import logging import os import re import subprocess -import sys from configparser import ConfigParser -from functools import wraps from multiprocessing import AuthenticationError from os import PathLike from pathlib import Path, PurePosixPath -from typing import Optional, TypeVar, Union, Dict, Set, Tuple, List, Callable +from typing import Optional, TypeVar, Union, Dict, Set, Tuple, List from urllib.parse import urlparse, ParseResult from tenacity import RetryError -from dagshub.common import config, is_inside_notebook, is_inside_colab +from dagshub.common import config from dagshub.common.api.repo import RepoAPI, CommitNotFoundError, PathNotFoundError, DagsHubHTTPError from dagshub.common.api.responses import ContentAPIEntry from dagshub.common.helpers import get_project_root -from dagshub.streaming.dataclasses import DagshubPath -from 
dagshub.streaming.errors import FilesystemAlreadyMountedError - -# Pre 3.11 - need to patch _NormalAccessor for _pathlib, because it pre-caches open and other functions. -# In 3.11 _NormalAccessor was removed -PRE_PYTHON3_11 = sys.version_info.major == 3 and sys.version_info.minor < 11 -if PRE_PYTHON3_11: - from pathlib import _NormalAccessor as _pathlib # noqa +from dagshub.streaming.dataclasses import ( + DagshubPath, + DagshubScandirIterator, + DagshubDirEntry, + DagshubStatResult, + PathTypeWithDagshubPath, +) +from dagshub.streaming.hook_router import HookRouter +from dagshub.streaming.util import wrapreturn try: from functools import cached_property @@ -38,34 +35,6 @@ logger = logging.getLogger(__name__) -def wrapreturn(wrappertype): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - return wrappertype(func(*args, **kwargs)) - - return wrapper - - return decorator - - -class DagshubScandirIterator: - def __init__(self, iterator): - self._iterator = iterator - - def __iter__(self): - return self._iterator - - def __next__(self): - return self._iterator.__next__() - - def __enter__(self): - return self - - def __exit__(self, *args): - return self - - SPECIAL_FILE = Path(".dagshub-streaming") @@ -94,16 +63,6 @@ class DagsHubFilesystem: - ``transformers`` - patches ``safetensors`` """ - already_mounted_filesystems: Dict[Path, "DagsHubFilesystem"] = {} - hooked_instance: Optional["DagsHubFilesystem"] = None - - # Framework-specific override functions. 
- # These functions will be patched with a function that calls fs.open() before calling the original function - # Classes are marked by $, so if you need to change a static/class method, use module.$class.func - _framework_override_map: Dict[str, List[str]] = { - "transformers": ["safetensors.safe_open", "tokenizers.$Tokenizer.from_file"], - } - def __init__( self, project_root: Optional["PathLike | str"] = None, @@ -164,8 +123,6 @@ def __init__( self._api = self._generate_repo_api(self.parsed_repo_url) - self.check_project_root_use() - # Check that the repo is accessible by accessing the content root response = self._api_listdir(DagshubPath(self, self.project_root)) if response is None: @@ -232,26 +189,6 @@ def is_commit_on_remote(self, sha1): except CommitNotFoundError: return False - def check_project_root_use(self): - """ - Checks that there's no other filesystem being mounted at the current project root - If there is one, throw an error - - :meta private: - """ - - def is_subpath(a: Path, b: Path) -> bool: - # Checks if either a or b are subpaths of each other - a_str = a.as_posix() - b_str = b.as_posix() - return a_str.startswith(b_str) or b_str.startswith(a_str) - - for p, f in DagsHubFilesystem.already_mounted_filesystems.items(): - if is_subpath(p, self.project_root): - raise FilesystemAlreadyMountedError(self.project_root, f.parsed_repo_url.path[1:], f._current_revision) - - DagsHubFilesystem.already_mounted_filesystems[self.project_root] = self - @property def auth(self): import dagshub.auth @@ -295,14 +232,6 @@ def get_remotes_from_git_config(self) -> List[str]: res_remotes.append(remote.geturl()) return res_remotes - def __del__(self): - self.cleanup() - - def cleanup(self): - # Remove from map of mounted filesystems - if hasattr(self, "project_root") and self.project_root in DagsHubFilesystem.already_mounted_filesystems: - DagsHubFilesystem.already_mounted_filesystems.pop(self.project_root) - @staticmethod def _special_file(): # TODO Include more 
information in this file @@ -310,7 +239,7 @@ def _special_file(): def open( self, - file: Union[str, int, bytes, PathLike, DagshubPath], + file: PathTypeWithDagshubPath, mode="r", buffering=-1, encoding=None, @@ -343,7 +272,7 @@ def open( """ # FD passthrough if isinstance(file, int): - return self.__open(file, mode, buffering, encoding, errors, newline, closefd) + return self.original_open(file, mode, buffering, encoding, errors, newline, closefd) if isinstance(file, bytes): file = os.fsdecode(file) @@ -354,12 +283,12 @@ def open( if opener is not None: raise NotImplementedError("DagsHub's patched open() does not support custom openers") if path.is_passthrough_path(self): - return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) + return self.original_open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) elif path.relative_path == SPECIAL_FILE: return io.BytesIO(self._special_file()) else: try: - return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) + return self.original_open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) except FileNotFoundError as err: # Open for reading - try to download the file if "r" in mode: @@ -369,11 +298,13 @@ def open( raise RuntimeError(f"Couldn't download {path.relative_path} after multiple attempts") except PathNotFoundError: raise FileNotFoundError(f"Error finding {path.relative_path} in repo or on DagsHub") - self._mkdirs(path.absolute_path.parent) + self.mkdirs(path.absolute_path.parent) # TODO: Handle symlinks - with self.__open(path.absolute_path, "wb") as output: + with self.original_open(path.absolute_path, "wb") as output: output.write(contents) - return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) + return self.original_open( + path.absolute_path, mode, buffering, encoding, errors, newline, closefd + ) # Write modes - make sure that the folder is a tracked folder 
(create if it doesn't exist on disk), # and then let the user write to file else: @@ -390,12 +321,14 @@ def open( raise RuntimeError(f"Couldn't download {path.relative_path} after multiple attempts") except PathNotFoundError: raise FileNotFoundError(f"Error finding {path.relative_path} in repo or on DagsHub") - with self.__open(path.absolute_path, "wb") as output: + with self.original_open(path.absolute_path, "wb") as output: output.write(contents) - return self.__open(path.absolute_path, mode, buffering, encoding, errors, newline, closefd) + return self.original_open( + path.absolute_path, mode, buffering, encoding, errors, newline, closefd + ) else: - return self.__open(file, mode, buffering, encoding, errors, newline, closefd, opener) + return self.original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) def os_open(self, path: Union[str, bytes, PathLike, DagshubPath], flags, mode=0o777, *, dir_fd=None): """ @@ -429,7 +362,7 @@ def os_open(self, path: Union[str, bytes, PathLike, DagshubPath], flags, mode=0o logger.debug("fs.os_open - failed to materialize path, os.open will throw") return os.open(dh_path.absolute_path, flags, mode, dir_fd=dir_fd) - def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follow_symlinks=True): + def stat(self, path: PathTypeWithDagshubPath, *args, dir_fd=None, follow_symlinks=True): """ NOTE: This is a wrapper function for python's built-in file operations (https://docs.python.org/3/library/os.html#os.stat) @@ -449,7 +382,7 @@ def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follo """ # FD passthrough if isinstance(path, int): - return self.__stat(path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) + return self.original_stat(path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) if isinstance(path, bytes): path = os.fsdecode(path) @@ -457,19 +390,18 @@ def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follo 
logger.debug("fs.stat - NotImplemented") raise NotImplementedError("DagsHub's patched stat() does not support dir_fd or follow_symlinks") parsed_path = DagshubPath(self, path) - # todo: remove False if parsed_path.is_in_repo: assert parsed_path.relative_path is not None assert parsed_path.absolute_path is not None logger.debug("fs.stat - is relative path") if parsed_path.is_passthrough_path(self): - return self.__stat(parsed_path.absolute_path) + return self.original_stat(parsed_path.absolute_path) elif parsed_path.relative_path == SPECIAL_FILE: return DagshubStatResult(self, parsed_path, is_directory=False, custom_size=len(self._special_file())) else: try: logger.debug(f"fs.stat - calling __stat - relative_path: {path}") - return self.__stat(parsed_path.absolute_path) + return self.original_stat(parsed_path.absolute_path) except FileNotFoundError as err: logger.debug("fs.stat - FileNotFoundError") logger.debug(f"remote_tree: {self.remote_tree}") @@ -494,15 +426,14 @@ def stat(self, path: Union[str, int, bytes, PathLike], *args, dir_fd=None, follo if filetype == "file": return DagshubStatResult(self, parsed_path, is_directory=False) elif filetype == "dir": - self._mkdirs(parsed_path.absolute_path) - return self.__stat(parsed_path.absolute_path) + self.mkdirs(parsed_path.absolute_path) + return self.original_stat(parsed_path.absolute_path) else: raise RuntimeError(f"Unknown file type {filetype} for path {str(parsed_path)}") - else: - return self.__stat(path, follow_symlinks=follow_symlinks) + return self.original_stat(path, follow_symlinks=follow_symlinks) - def chdir(self, path): + def chdir(self, path: PathTypeWithDagshubPath): """ NOTE: This is a wrapper function for python's built-in file operations (https://docs.python.org/3/library/os.html#os.chdir) @@ -517,26 +448,26 @@ def chdir(self, path): """ # FD check if isinstance(path, int): - return self.__chdir(path) + return self.original_chdir(path) if isinstance(path, bytes): path = os.fsdecode(path) parsed_path 
= DagshubPath(self, path) if parsed_path.is_in_repo: try: - self.__chdir(parsed_path.absolute_path) + self.original_chdir(parsed_path.absolute_path) except FileNotFoundError: resp = self._api_listdir(parsed_path) # FIXME: if path is file, return FileNotFound instead of the listdir error if resp is not None: - self._mkdirs(parsed_path.absolute_path) - self.__chdir(parsed_path.absolute_path) + self.mkdirs(parsed_path.absolute_path) + self.original_chdir(parsed_path.absolute_path) else: raise else: - self.__chdir(path) + self.original_chdir(path) - def listdir(self, path="."): + def listdir(self, path: PathTypeWithDagshubPath = "."): """ NOTE: This is a wrapper function for python's built-in file operations (https://docs.python.org/3/library/os.html#os.listdir) @@ -555,22 +486,17 @@ def listdir(self, path="."): """ # FD check if isinstance(path, int): - return self.__listdir(path) + return self.original_listdir(path) - # listdir needs to return results for bytes path arg also in bytes - is_bytes_path_arg = isinstance(path, bytes) + parsed_path = DagshubPath(self, path) + # listdir needs to return results for bytes path arg also in bytes def encode_results(res): res = list(res) - if is_bytes_path_arg: + if parsed_path.is_binary_path_requested: res = [os.fsencode(p) for p in res] return res - if is_bytes_path_arg: - str_path = os.fsdecode(path) - else: - str_path = path - parsed_path = DagshubPath(self, str_path) if parsed_path.is_in_repo: if parsed_path.is_passthrough_path(self): return self.listdir(parsed_path.original_path) @@ -578,16 +504,17 @@ def encode_results(res): dircontents: Set[str] = set() error = None try: - dircontents.update(self.__listdir(parsed_path.original_path)) + dircontents.update(self.original_listdir(parsed_path.original_path)) except FileNotFoundError as e: error = e dircontents.update( special.name for special in self._get_special_paths( - parsed_path, self.project_root_dagshub_path, is_bytes_path_arg + parsed_path, 
self.project_root_dagshub_path, parsed_path.is_binary_path_requested ) ) # If we're accessing .dagshub/storage/s3/ we don't need to access the API, return straight away + assert parsed_path.relative_path is not None len_parts = len(parsed_path.relative_path.parts) if 0 < len_parts <= 3 and parsed_path.relative_path.parts[0] == ".dagshub": return encode_results(dircontents) @@ -603,37 +530,33 @@ def encode_results(res): return encode_results(dircontents) else: - return self.__listdir(path) + return self.original_listdir(path) @cached_property def project_root_dagshub_path(self): return DagshubPath(self, self.project_root) @wrapreturn(DagshubScandirIterator) - def scandir(self, path="."): + def scandir(self, path: PathTypeWithDagshubPath = "."): # FD check if isinstance(path, int): - for direntry in self.__scandir(path): + for direntry in self.original_scandir(path): yield direntry return - # scandir needs to return name and path as bytes, if entry arg is bytes - is_bytes_path_arg = isinstance(path, bytes) - if is_bytes_path_arg: - str_path = os.fsdecode(path) - else: - str_path = path - parsed_path = DagshubPath(self, str_path) + + parsed_path = DagshubPath(self, path) + if parsed_path.is_in_repo and not parsed_path.is_passthrough_path(self): - path = Path(str_path) + path = Path(parsed_path.original_path) local_filenames = set() try: - for direntry in self.__scandir(path): + for direntry in self.original_scandir(path): local_filenames.add(direntry.name) yield direntry except FileNotFoundError: pass for special_entry in self._get_special_paths( - parsed_path, self.project_root_dagshub_path / path, is_bytes_path_arg + parsed_path, self.project_root_dagshub_path / path, parsed_path.is_binary_path_requested ): if special_entry.path not in local_filenames: yield special_entry @@ -643,9 +566,11 @@ def scandir(self, path="."): for f in resp: name = PurePosixPath(f.path).name if name not in local_filenames: - yield DagshubDirEntry(self, parsed_path / name, f.type == "dir", 
is_binary=is_bytes_path_arg) + yield DagshubDirEntry( + self, parsed_path / name, f.type == "dir", is_binary=parsed_path.is_binary_path_requested + ) else: - for entry in self.__scandir(path): + for entry in self.original_scandir(path): yield entry def _get_special_paths( @@ -658,6 +583,7 @@ def generate_entry(path, is_directory): has_storages = len(self._storages) > 0 res = set() + assert dh_path.relative_path is not None str_path = dh_path.relative_path.as_posix() if str_path == ".": res.add(generate_entry(SPECIAL_FILE, False)) @@ -707,24 +633,6 @@ def _check_listdir_cache(self, path: str, include_size: bool) -> Tuple[Optional[ return cache_val, True return None, False - def _content_url_for_path(self, path: DagshubPath): - if not path.is_in_repo: - raise RuntimeError(f"Can't access path {path.absolute_path} outside of repo") - str_path = path.relative_path.as_posix() - if path.is_storage_path: - path_to_access = str_path[len(".dagshub/storage/") :] - return self._api.storage_content_api_url(path_to_access) - return self._api.content_api_url(str_path, self._current_revision) - - def _raw_url_for_path(self, path: DagshubPath): - if not path.is_in_repo: - raise RuntimeError(f"Can't access path {path.absolute_path} outside of repo") - str_path = path.relative_path.as_posix() - if path.is_storage_path: - path_to_access = str_path[len(".dagshub/storage/") :] - return self._api.storage_raw_api_url(path_to_access) - return self._api.raw_api_url(str_path, self._current_revision) - def _api_download_file_git(self, path: DagshubPath) -> bytes: if path.relative_path is None: raise RuntimeError(f"Can't access path {path.absolute_path} outside of repo") @@ -734,229 +642,42 @@ def _api_download_file_git(self, path: DagshubPath) -> bytes: return self._api.get_storage_file(str_path) return self._api.get_file(str_path, self._current_revision) - def install_hooks(self): - """ - Install hooks to override default file and directory operations with DagsHub-aware functionality. 
- - This method patches the standard Python I/O operations such as ``open``, - ``stat``, ``listdir``, ``scandir``, and ``chdir`` with DagsHub-aware equivalents. - Works inside a notebook and with Pathlib. - - If ``install_hooks()`` have already been called before, this method does nothing. - - Example:: - - dagshub_fs = DagsHubFilesystem() - dagshub_fs.install_hooks() - - with open("src/file_in_repo.txt") as f: - print(f.read()) - - Call :func:`~DagsHubFilesystem.uninstall_hooks` to undo the monkey patching. - """ - if not hasattr(self.__class__, f"_{self.__class__.__name__}__unpatched"): - # TODO: DRY this dictionary. i.e. __open() links cls.__open - # and io.open even though this dictionary links them - # Cannot use a dict as the source of truth because type hints rely on - # __get_unpatched inferring the right type - self.__class__.__unpatched = { - "open": builtins.open, - "stat": os.stat, - "listdir": os.listdir, - "scandir": os.scandir, - "chdir": os.chdir, - } - if PRE_PYTHON3_11: - self.__class__.__unpatched["pathlib_open"] = _pathlib.open - - # IPython patches io.open to its own override, so we need to overwrite that also - # More at _modified_open function in IPython sources: - # https://github.com/ipython/ipython/blob/main/IPython/core/interactiveshell.py - if is_inside_notebook() and not is_inside_colab(): - import IPython.core.interactiveshell - - instance = IPython.core.interactiveshell.InteractiveShell._instance # noqa - if instance is not None and hasattr(instance, "user_ns") and "open" in instance.user_ns: - self.__class__.__unpatched["notebook_open"] = instance.user_ns["open"] - instance.user_ns["open"] = self.open - - io.open = builtins.open = self.open - os.stat = self.stat - os.listdir = self.listdir - os.scandir = self.scandir - os.chdir = self.chdir - if PRE_PYTHON3_11: - if sys.version_info.minor == 10: - # Python 3.10 - pathlib uses io.open - _pathlib.open = self.open - else: - # Python <=3.9 - pathlib uses os.open - _pathlib.open = 
self.os_open - _pathlib.stat = self.stat - _pathlib.listdir = self.listdir - _pathlib.scandir = self.scandir - - self._install_framework_hooks() - - DagsHubFilesystem.hooked_instance = self - - _framework_key_prefix = "framework_" - - def _install_framework_hooks(self): - """ - Installs custom hook functions for frameworks - """ - if self.frameworks is None: - return - for framework in self.frameworks: - if framework not in self._framework_override_map: - logger.warning(f"Framework {framework} not available for override, skipping") - continue - funcs = self._framework_override_map[framework] - for func in funcs: - module_name, func_name = func.rsplit(".", 1) - class_name = None - patch_class = None - - # Handle static class methods - we'll need to get the class from the module first - if "$" in module_name: - module_name, class_name = module_name.split("$") - # Get rid of the . in the module name - module_name = module_name[:-1] - - try: - patch_module = importlib.import_module(module_name) - if class_name is not None: - patch_class = getattr(patch_module, class_name) - orig_fn = getattr(patch_class, func_name) - else: - orig_fn = getattr(patch_module, func_name) - except ModuleNotFoundError: - logger.warning(f"Module [{module_name}] not found, so function [{func}] isn't being patched") - continue - except AttributeError: - logger.warning(f"Function [{func}] not found, not patching it") - continue - self.__class__.__unpatched[f"{self._framework_key_prefix}{func}"] = orig_fn - if patch_class is not None: - setattr(patch_class, func_name, self._passthrough_decorator(orig_fn)) - else: - setattr(patch_module, func_name, self._passthrough_decorator(orig_fn)) - - def _passthrough_decorator(self, orig_func, filearg: Union[int, str] = 0) -> Callable: - """ - Decorator function over some other random function that assumes a file exists locally, - but isn't using python's open(). These might be C++/Rust functions that use their respective opens. 
- Examples: opencv, anything using pyo3 - - Working around the problem by first calling open().close() to get the file. - - :param orig_func: the original function that needs to be called - :param filearg: int or string, which arg/kwarg to use to get the filename - :return: Wrapped orig_func - """ - - def passed_through(*args, **kwargs): - if type(filearg) is str: - filename = kwargs[filearg] - else: - filename = args[filearg] - self.open(filename).close() - return orig_func(*args, **kwargs) - - return passed_through - - @classmethod - def uninstall_hooks(cls): - """ - Reverses the changes made by :func:`install_hooks`, bringing back the builtin file I/O functions. - """ - if hasattr(cls, f"_{cls.__name__}__unpatched"): - io.open = builtins.open = cls.__unpatched["open"] - os.stat = cls.__unpatched["stat"] - os.listdir = cls.__unpatched["listdir"] - os.scandir = cls.__unpatched["scandir"] - os.chdir = cls.__unpatched["chdir"] - if PRE_PYTHON3_11: - _pathlib.open = cls.__unpatched["pathlib_open"] - _pathlib.stat = cls.__unpatched["stat"] - _pathlib.listdir = cls.__unpatched["listdir"] - _pathlib.scandir = cls.__unpatched["scandir"] - - if "notebook_open" in cls.__unpatched: - import IPython.core.interactiveshell - - instance = IPython.core.interactiveshell.InteractiveShell._instance # noqa - if instance is not None and hasattr(instance, "user_ns"): - instance.user_ns["open"] = cls.__unpatched["notebook_open"] - - cls._uninstall_framework_hooks() - - if DagsHubFilesystem.hooked_instance is not None: - DagsHubFilesystem.hooked_instance.cleanup() - DagsHubFilesystem.hooked_instance = None - - @classmethod - def _uninstall_framework_hooks(cls): - for func in list(filter(lambda key: key.startswith(cls._framework_key_prefix), cls.__unpatched.keys())): - orig_fn = cls.__unpatched[func] - orig_func_name = func - - func = func[len(cls._framework_key_prefix) :] - module_name, func_name = func.rsplit(".", 1) - class_name = None - - if "$" in module_name: - module_name, 
class_name = module_name.split("$") - # Get rid of the . in the module name - module_name = module_name[:-1] - - m = importlib.import_module(module_name) - if class_name is not None: - patch_class = getattr(m, class_name) - setattr(patch_class, func_name, orig_fn) - else: - setattr(m, func_name, orig_fn) - - del cls.__unpatched[orig_func_name] - - def _mkdirs(self, absolute_path: Path): + def mkdirs(self, absolute_path: Path): for parent in list(absolute_path.parents)[::-1]: try: - self.__stat(parent) + self.original_stat(parent) except (OSError, ValueError): os.mkdir(parent) try: - self.__stat(absolute_path) + self.original_stat(absolute_path) except (OSError, ValueError): os.mkdir(absolute_path) - @classmethod - def __get_unpatched(cls, key, alt: T) -> T: - if hasattr(cls, f"_{cls.__name__}__unpatched"): - return cls.__unpatched[key] - else: - return alt - @property - def __open(self): - return self.__get_unpatched("open", builtins.open) + def original_open(self): + return HookRouter.original_open @property - def __stat(self): - return self.__get_unpatched("stat", os.stat) + def original_stat(self): + return HookRouter.original_stat @property - def __listdir(self): - return self.__get_unpatched("listdir", os.listdir) + def original_listdir(self): + return HookRouter.original_listdir @property - def __scandir(self): - return self.__get_unpatched("scandir", os.scandir) + def original_scandir(self): + return HookRouter.original_scandir @property - def __chdir(self): - return self.__get_unpatched("chdir", os.chdir) + def original_chdir(self): + return HookRouter.original_chdir + + def install_hooks(self): + HookRouter.hook_repo(self, frameworks=self.frameworks) + + def uninstall_hooks(self): + HookRouter.unhook_repo(self) def install_hooks( @@ -993,115 +714,14 @@ def install_hooks( exclude_globs=exclude_globs, frameworks=frameworks, ) - fs.install_hooks() + HookRouter.hook_repo(fs, frameworks) def uninstall_hooks(): """ Reverses the changes made by 
:func:`install_hooks` """ - DagsHubFilesystem.uninstall_hooks() - - -class DagshubStatResult: - def __init__(self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool, custom_size: int = None): - self._fs = fs - self._path = path - self._is_directory = is_directory - self._custom_size = custom_size - assert not self._is_directory # TODO make folder stats lazy? - - def __getattr__(self, name: str): - if not name.startswith("st_"): - raise AttributeError - if hasattr(self, "_true_stat"): - return os.stat_result.__getattribute__(self._true_stat, name) - if name == "st_uid": - return os.getuid() - elif name == "st_gid": - return os.getgid() - elif name == "st_atime" or name == "st_mtime" or name == "st_ctime": - return 0 - elif name == "st_mode": - return 0o100644 - elif name == "st_size": - if self._custom_size: - return self._custom_size - return 1100 # hardcoded size because size requests take a disproportionate amount of time - self._fs.open(self._path) - self._true_stat = self._fs._DagsHubFilesystem__stat(self._path.absolute_path) - return os.stat_result.__getattribute__(self._true_stat, name) - - def __repr__(self): - inner = repr(self._true_stat) if hasattr(self, "_true_stat") else "pending..." 
- return f"dagshub_stat_result({inner}, path={self._path})" - - -class DagshubDirEntry: - def __init__(self, fs: "DagsHubFilesystem", path: DagshubPath, is_directory: bool = False, is_binary: bool = False): - self._fs = fs - self._path = path - self._is_directory = is_directory - self._is_binary = is_binary - - @property - def name(self): - # TODO: create decorator for delegation - if hasattr(self, "_true_direntry"): - name = self._true_direntry.name - else: - name = self._path.name - return os.fsencode(name) if self._is_binary else name - - @property - def path(self): - if hasattr(self, "_true_direntry"): - path = self._true_direntry.path - else: - path = str(self._path.original_path) - return os.fsencode(path) if self._is_binary else path - - def is_dir(self): - if hasattr(self, "_true_direntry"): - return self._true_direntry.is_dir() - else: - return self._is_directory - - def is_file(self): - if hasattr(self, "_true_direntry"): - return self._true_direntry.is_file() - else: - # TODO: Symlinks should return false - return not self._is_directory - - def stat(self): - if hasattr(self, "_true_direntry"): - return self._true_direntry.stat() - else: - return self._fs.stat(self._path.original_path) - - def __getattr__(self, name: str): - if name == "_true_direntry": - raise AttributeError - if hasattr(self, "_true_direntry"): - return os.DirEntry.__getattribute__(self._true_direntry, name) - - # Either create a dir, or download the file - if self._is_directory: - self._fs._mkdirs(self._path.absolute_path) - else: - self._fs.open(self._path.absolute_path) - - for direntry in self._fs._DagsHubFilesystem__scandir(self._path.original_path): - if direntry.name == self._path.name: - self._true_direntry = direntry - return os.DirEntry.__getattribute__(self._true_direntry, name) - else: - raise FileNotFoundError - - def __repr__(self): - cached = " (cached)" if hasattr(self, "_true_direntry") else "" - return f"" + HookRouter.uninstall_monkey_patch() __all__ = 
[DagsHubFilesystem.__name__, install_hooks.__name__] diff --git a/dagshub/streaming/hook_router.py b/dagshub/streaming/hook_router.py index 8ca20b61..588f95b3 100644 --- a/dagshub/streaming/hook_router.py +++ b/dagshub/streaming/hook_router.py @@ -1,12 +1,318 @@ +import builtins +import importlib +import io +import logging +import os +import sys from os import PathLike -from typing import Union, Optional +from typing import Union, Optional, Callable, Dict, List, TYPE_CHECKING -from dagshub.streaming import DagsHubFilesystem +from dagshub.common import is_inside_notebook, is_inside_colab +from dagshub.streaming.dataclasses import DagshubPath, PathType +from dagshub.streaming.filesystem import DagshubScandirIterator +from dagshub.streaming.util import wrapreturn + +if TYPE_CHECKING: + from dagshub.streaming.filesystem import DagsHubFilesystem + +# Pre 3.11 - need to patch _NormalAccessor for _pathlib, because it pre-caches open and other functions. +# In 3.11 _NormalAccessor was removed +PRE_PYTHON3_11 = sys.version_info.major == 3 and sys.version_info.minor < 11 +if PRE_PYTHON3_11: + from pathlib import _NormalAccessor as _pathlib # noqa + + +logger = logging.getLogger(__name__) class HookRouter: - def install_hooks(self, fs: DagsHubFilesystem): - pass + original_open = builtins.open + original_stat = os.stat + original_listdir = os.listdir + original_scandir = os.scandir + original_chdir = os.chdir + + unpatched: Dict[str, Callable] = {} + + is_monkey_patching = False + + active_filesystems: List["DagsHubFilesystem"] = [] + + # Framework-specific override functions. 
+ # These functions will be patched with a function that calls fs.open() before calling the original function + # Classes are marked by $, so if you need to change a static/class method, use module.$class.func + _framework_override_map: Dict[str, List[str]] = { + "transformers": ["safetensors.safe_open", "tokenizers.$Tokenizer.from_file"], + } + + @classmethod + def open( + cls, + file: Union[str, int, bytes, PathLike, DagshubPath], + mode="r", + buffering=-1, + encoding=None, + errors=None, + newline=None, + closefd=True, + opener=None, + ): + if isinstance(file, int): + return cls.original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) + dh_path = cls.determine_fs(file) + if dh_path is not None: + return dh_path.fs.open(dh_path, mode, buffering, encoding, errors, newline, closefd, opener) + else: + return cls.original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) + + @classmethod + def stat(cls, path: PathType, *args, dir_fd=None, follow_symlinks=True): + if isinstance(path, int): + return cls.original_stat(path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) + dh_path = cls.determine_fs(path) + if dh_path is not None: + return dh_path.fs.stat(dh_path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) + else: + return cls.original_stat(path, *args, dir_fd=dir_fd, follow_symlinks=follow_symlinks) + + @classmethod + def listdir(cls, path: PathType = "."): + if isinstance(path, int): + return cls.original_listdir(path) + dh_path = cls.determine_fs(path) + if dh_path is not None: + return dh_path.fs.listdir(dh_path) + else: + return cls.original_listdir(path) + + @classmethod + @wrapreturn(DagshubScandirIterator) + def scandir(cls, path: PathType = "."): + if isinstance(path, int): + return cls.original_scandir(path) + dh_path = cls.determine_fs(path) + if dh_path is not None: + return dh_path.fs.scandir(path) + else: + return cls.original_scandir(path) + + @classmethod + def chdir(cls, path: PathType): 
+ if isinstance(path, int): + return cls.original_chdir(path) + dh_path = cls.determine_fs(path) + if dh_path is not None: + return dh_path.fs.chdir(dh_path) + else: + return cls.original_chdir(path) + + @classmethod + def os_open(cls, path: Union[str, bytes, PathLike], flags, mode=0o777, *args, dir_fd=None): + dh_path = cls.determine_fs(path) + if dh_path is not None: + return dh_path.fs.os_open(dh_path, flags, mode, *args, dir_fd=dir_fd) + else: + return os.open(path, flags, mode, *args, dir_fd=dir_fd) + + @classmethod + def init_monkey_patch(cls, frameworks: Optional[List[str]] = None): + if cls.is_monkey_patching: + return + # Save the current unpatched functions + cls.unpatched = { + "open": cls.original_open, + "stat": cls.original_stat, + "listdir": cls.original_listdir, + "scandir": cls.original_scandir, + "chdir": cls.original_chdir, + } + if PRE_PYTHON3_11: + cls.unpatched["pathlib_open"] = _pathlib.open + + # IPython patches io.open to its own override, so we need to overwrite that also + # More at _modified_open function in IPython sources: + # https://github.com/ipython/ipython/blob/main/IPython/core/interactiveshell.py + if is_inside_notebook() and not is_inside_colab(): + import IPython.core.interactiveshell + + instance = IPython.core.interactiveshell.InteractiveShell._instance # noqa + if instance is not None and hasattr(instance, "user_ns") and "open" in instance.user_ns: + cls.unpatched["notebook_open"] = instance.user_ns["open"] + instance.user_ns["open"] = cls.open + + io.open = builtins.open = cls.open + os.stat = cls.stat + os.listdir = cls.listdir + os.scandir = cls.scandir + os.chdir = cls.chdir + if PRE_PYTHON3_11: + if sys.version_info.minor == 10: + # Python 3.10 - pathlib uses io.open + _pathlib.open = cls.open + else: + # Python <=3.9 - pathlib uses os.open + _pathlib.open = cls.os_open + _pathlib.stat = cls.stat + _pathlib.listdir = cls.listdir + _pathlib.scandir = cls.scandir + + cls._install_framework_hooks(frameworks) + 
cls.is_monkey_patching = True + + @classmethod + def uninstall_monkey_patch(cls): + if not cls.is_monkey_patching: + return + io.open = builtins.open = cls.unpatched["open"] + os.stat = cls.unpatched["stat"] + os.listdir = cls.unpatched["listdir"] + os.scandir = cls.unpatched["scandir"] + os.chdir = cls.unpatched["chdir"] + if PRE_PYTHON3_11: + _pathlib.open = cls.unpatched["pathlib_open"] + _pathlib.stat = cls.unpatched["stat"] + _pathlib.listdir = cls.unpatched["listdir"] + _pathlib.scandir = cls.unpatched["scandir"] + + if "notebook_open" in cls.unpatched: + import IPython.core.interactiveshell + + instance = IPython.core.interactiveshell.InteractiveShell._instance # noqa + if instance is not None and hasattr(instance, "user_ns"): + instance.user_ns["open"] = cls.unpatched["notebook_open"] + + cls._uninstall_framework_hooks() + cls.active_filesystems.clear() + cls.is_monkey_patching = False + + _framework_key_prefix = "framework_" + + @classmethod + def _install_framework_hooks(cls, frameworks: Optional[List[str]]): + """ + Installs custom hook functions for frameworks + """ + if frameworks is None: + return + for framework in frameworks: + if framework not in cls._framework_override_map: + logger.warning(f"Framework {framework} not available for override, skipping") + continue + funcs = cls._framework_override_map[framework] + for func in funcs: + module_name, func_name = func.rsplit(".", 1) + class_name = None + patch_class = None + + # Handle static class methods - we'll need to get the class from the module first + if "$" in module_name: + module_name, class_name = module_name.split("$") + # Get rid of the . 
in the module name + module_name = module_name[:-1] + + try: + patch_module = importlib.import_module(module_name) + if class_name is not None: + patch_class = getattr(patch_module, class_name) + orig_fn = getattr(patch_class, func_name) + else: + orig_fn = getattr(patch_module, func_name) + except ModuleNotFoundError: + logger.warning(f"Module [{module_name}] not found, so function [{func}] isn't being patched") + continue + except AttributeError: + logger.warning(f"Function [{func}] not found, not patching it") + continue + cls.unpatched[f"{cls._framework_key_prefix}{func}"] = orig_fn + if patch_class is not None: + setattr(patch_class, func_name, cls._passthrough_decorator(orig_fn)) + else: + setattr(patch_module, func_name, cls._passthrough_decorator(orig_fn)) + + @classmethod + def _uninstall_framework_hooks(cls): + for func in list(filter(lambda key: key.startswith(cls._framework_key_prefix), cls.unpatched.keys())): + orig_fn = cls.unpatched[func] + orig_func_name = func + + func = func[len(cls._framework_key_prefix) :] + module_name, func_name = func.rsplit(".", 1) + class_name = None + + if "$" in module_name: + module_name, class_name = module_name.split("$") + # Get rid of the . in the module name + module_name = module_name[:-1] + + m = importlib.import_module(module_name) + if class_name is not None: + patch_class = getattr(m, class_name) + setattr(patch_class, func_name, orig_fn) + else: + setattr(m, func_name, orig_fn) + + del cls.unpatched[orig_func_name] + + @classmethod + def _passthrough_decorator(cls, orig_func, filearg: Union[int, str] = 0) -> Callable: + """ + Decorator function over some other random function that assumes a file exists locally, + but isn't using python's open(). These might be C++/Rust functions that use their respective opens. + Examples: opencv, anything using pyo3 + + Working around the problem by first calling open().close() to get the file. 
+ + :param orig_func: the original function that needs to be called + :param filearg: int or string, which arg/kwarg to use to get the filename + :return: Wrapped orig_func + """ + + def passed_through(*args, **kwargs): + if isinstance(filearg, str): + filename = kwargs[filearg] + else: + filename = args[filearg] + cls.open(filename).close() + return orig_func(*args, **kwargs) + + return passed_through + + @staticmethod + def _dagshub_path_relative_length(dhp: DagshubPath) -> int: + if dhp.relative_path is None: + raise RuntimeError(f"Tried to get length of the nonexistent relative path for dagshub path {dhp}") + return len(dhp.relative_path.parents) + + @classmethod + def determine_fs(cls, path: Union[str, bytes, PathLike]) -> Optional[DagshubPath]: + """ + Determine the hooked filesystem that path belongs to + If it belongs to multiple filesystems, then the one with the most specific path will be returned + + If file doesn't belong to any filesystem, then returns None + """ + possible_paths = [] + for fs in cls.active_filesystems: + parsed = DagshubPath(fs, path) + if parsed.is_in_repo: + possible_paths.append(parsed) + if len(possible_paths) > 0: + # Return the path that has the shortest relative path (most specific) + return min(possible_paths, key=cls._dagshub_path_relative_length) + return None + + @classmethod + def hook_repo(cls, fs: "DagsHubFilesystem", frameworks: Optional[List[str]]): + if not cls.is_monkey_patching: + cls.init_monkey_patch(frameworks) + cls.active_filesystems.append(fs) + + @classmethod + def unhook_repo(cls, fs: Optional["DagsHubFilesystem"] = None, path: Optional[Union[str, PathLike]] = None): + if fs in cls.active_filesystems: + cls.active_filesystems.remove(fs) + if path is not None: + raise NotImplementedError("Unhooking by path is not implemented yet") - def uninstall_hooks(self, fs: Optional[DagsHubFilesystem]=None, path: Optional[Union[str, PathLike]]=None): - pass + if len(cls.active_filesystems): + cls.uninstall_monkey_patch() 
diff --git a/dagshub/streaming/util.py b/dagshub/streaming/util.py new file mode 100644 index 00000000..6eb9031c --- /dev/null +++ b/dagshub/streaming/util.py @@ -0,0 +1,12 @@ +from functools import wraps + + +def wrapreturn(wrapper_type): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + return wrapper_type(func(*args, **kwargs)) + + return wrapper + + return decorator diff --git a/requirements-dev.txt b/requirements-dev.txt index f26c533c..b2bc3dce 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,7 @@ respx==0.20.2 pytest-git==1.7.0 pytest-env==1.1.3 fiftyone==0.23.5 +ruff==0.3.2 +mypy==1.8.0 +types-python-dateutil==2.8.19.14 +types-pytz==2023.3.1.1 diff --git a/tests/dda/filesystem/test_multihooks.py b/tests/dda/filesystem/test_multihooks.py index b51177b1..9da3d12a 100644 --- a/tests/dda/filesystem/test_multihooks.py +++ b/tests/dda/filesystem/test_multihooks.py @@ -55,26 +55,53 @@ def test_mock_fs_works(repo_1, tmp_path): pass -def test_two_mock_fs(repo_1, repo_2, tmp_path): - path1 = tmp_path / "repo1" - path2 = tmp_path / "repo2" +@pytest.mark.parametrize( + "repo_1_dir, repo_2_dir", [("repo1", "repo2"), ("mount", "mount/repo2"), ("mount/repo1", "mount")] +) +def test_two_mock_fs(repo_1, repo_2, tmp_path, repo_1_dir, repo_2_dir): + path1 = tmp_path / repo_1_dir + path2 = tmp_path / repo_2_dir fs1 = generate_mock_fs(repo_1, path1) fs2 = generate_mock_fs(repo_2, path2) - assert fs1.open(path1 / "a/b.txt", "rb").read() == b"content repo 1" - assert fs2.open(path2 / "a/b.txt", "rb").read() == b"content repo 2" + try: + fs1.install_hooks() + fs2.install_hooks() + assert open(path1 / "a/b.txt", "rb").read() == b"content repo 1" + assert open(path2 / "a/b.txt", "rb").read() == b"content repo 2" + finally: + uninstall_hooks() + + +def test_nesting_priority(repo_1, repo_2, tmp_path): + path1 = tmp_path / "mount" + path2 = tmp_path / "mount/repo2" + + repo_1.add_repo_file("repo2/a/b.txt", b"FAILED") -def 
test_install_hooks_two_fs(repo_1, repo_2, tmp_path): - path1 = tmp_path / "repo1" - path2 = tmp_path / "repo2" fs1 = generate_mock_fs(repo_1, path1) fs2 = generate_mock_fs(repo_2, path2) - try: fs1.install_hooks() fs2.install_hooks() - assert open(path1 / "a/b.txt", "rb").read() == b"content repo 1" + assert open(path2 / "a/b.txt", "rb").read() == b"content repo 2" + finally: + uninstall_hooks() + + +def test_nesting_priority_reverse_order(repo_1, repo_2, tmp_path): + path1 = tmp_path / "mount" + path2 = tmp_path / "mount/repo2" + + repo_1.add_repo_file("repo2/a/b.txt", b"FAILED") + + fs1 = generate_mock_fs(repo_1, path1) + fs2 = generate_mock_fs(repo_2, path2) + try: + fs2.install_hooks() + fs1.install_hooks() + assert open(path2 / "a/b.txt", "rb").read() == b"content repo 2" finally: uninstall_hooks() diff --git a/tests/dda/test_tokens.py b/tests/dda/test_tokens.py index 6c91553d..503d140a 100644 --- a/tests/dda/test_tokens.py +++ b/tests/dda/test_tokens.py @@ -111,7 +111,6 @@ def token_cleanup_test(token_cache, token): if token.token_text in line: failed = True break - print(token_cache.cache_location) assert not failed From 229f1b0f610ada28b78b95057fcd84a65913ac05 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 14 Mar 2024 14:17:44 +0200 Subject: [PATCH 05/13] Enhance install/uninstall hooks, write tests that handle most of the interesting cases --- dagshub/streaming/__init__.py | 10 ++- dagshub/streaming/filesystem.py | 57 ++++++++++++----- dagshub/streaming/hook_router.py | 46 +++++++++++--- tests/dda/filesystem/test_misc.py | 17 ----- tests/dda/filesystem/test_multihooks.py | 83 ++++++++++++++++++++++++- tests/dda/filesystem/test_open.py | 2 +- 6 files changed, 170 insertions(+), 45 deletions(-) diff --git a/dagshub/streaming/__init__.py b/dagshub/streaming/__init__.py index 65fb5b71..174ba7ec 100644 --- a/dagshub/streaming/__init__.py +++ b/dagshub/streaming/__init__.py @@ -1,4 +1,4 @@ -from .filesystem import DagsHubFilesystem, install_hooks, 
uninstall_hooks +from .filesystem import DagsHubFilesystem, install_hooks, uninstall_hooks, get_mounted_filesystems try: from .mount import mount @@ -15,4 +15,10 @@ def mount(*args, **kwargs): print(error) -__all__ = [DagsHubFilesystem.__name__, install_hooks.__name__, mount.__name__, uninstall_hooks.__name__] +__all__ = [ + DagsHubFilesystem.__name__, + install_hooks.__name__, + mount.__name__, + uninstall_hooks.__name__, + get_mounted_filesystems.__name__, +] diff --git a/dagshub/streaming/filesystem.py b/dagshub/streaming/filesystem.py index 7e8f32a0..a18c0ae5 100644 --- a/dagshub/streaming/filesystem.py +++ b/dagshub/streaming/filesystem.py @@ -75,6 +75,7 @@ def __init__( exclude_globs: Optional[Union[List[str], str]] = None, frameworks: Optional[List[str]] = None, ): + self.project_root: Path # Find root directory of Git project if not project_root: try: @@ -121,7 +122,7 @@ def __init__( self._listdir_cache: Dict[str, Tuple[Optional[List[ContentAPIEntry]], bool]] = {} - self._api = self._generate_repo_api(self.parsed_repo_url) + self.repo_api = self._generate_repo_api(self.parsed_repo_url) # Check that the repo is accessible by accessing the content root response = self._api_listdir(DagshubPath(self, self.project_root)) @@ -129,7 +130,7 @@ def __init__( # TODO: Check .dvc/config{,.local} for credentials raise AuthenticationError("DagsHub credentials required, however none provided or discovered") - self._storages = self._api.get_connected_storages() + self._storages = self.repo_api.get_connected_storages() def _generate_repo_api(self, repo_url: ParseResult) -> RepoAPI: host = f"{repo_url.scheme}://{repo_url.netloc}" @@ -137,7 +138,7 @@ def _generate_repo_api(self, repo_url: ParseResult) -> RepoAPI: return RepoAPI(repo=repo, host=host, auth=self.auth) @cached_property - def _current_revision(self) -> str: + def current_revision(self) -> str: """ Gets current revision on repo: - If User specified a branch, returns HEAD of that brunch on the remote @@ -169,22 
+170,22 @@ def _current_revision(self) -> str: "Couldn't get branch info from local git repository, " + "fetching default branch from the remote..." ) - branch = self._api.default_branch + branch = self.repo_api.default_branch # check if it is a commit sha, in that case do not load the sha sha_regex = re.compile(r"^([a-f0-9]){5,40}$") if sha_regex.match(branch): try: - self._api.get_commit_info(branch) + self.repo_api.get_commit_info(branch) return branch except CommitNotFoundError: pass - return self._api.last_commit_sha(branch) + return self.repo_api.last_commit_sha(branch) def is_commit_on_remote(self, sha1): try: - self._api.get_commit_info(sha1) + self.repo_api.get_commit_info(sha1) return True except CommitNotFoundError: return False @@ -612,9 +613,9 @@ def _api_listdir(self, path: DagshubPath, include_size: bool = False) -> Optiona try: if path.is_storage_path: storage_path = repo_path[len(".dagshub/storage/") :] - res = self._api.list_storage_path(storage_path, include_size=include_size) + res = self.repo_api.list_storage_path(storage_path, include_size=include_size) else: - res = self._api.list_path(repo_path, self._current_revision, include_size=include_size) + res = self.repo_api.list_path(repo_path, self.current_revision, include_size=include_size) except PathNotFoundError: self._listdir_cache[repo_path] = (None, True) return None @@ -639,8 +640,8 @@ def _api_download_file_git(self, path: DagshubPath) -> bytes: str_path = path.relative_path.as_posix() if path.is_storage_path: str_path = str_path[len(".dagshub/storage/") :] - return self._api.get_storage_file(str_path) - return self._api.get_file(str_path, self._current_revision) + return self.repo_api.get_storage_file(str_path) + return self.repo_api.get_file(str_path, self.current_revision) def mkdirs(self, absolute_path: Path): for parent in list(absolute_path.parents)[::-1]: @@ -653,6 +654,9 @@ def mkdirs(self, absolute_path: Path): except (OSError, ValueError): os.mkdir(absolute_path) + def 
__del__(self): + self.uninstall_hooks() + @property def original_open(self): return HookRouter.original_open @@ -677,7 +681,7 @@ def install_hooks(self): HookRouter.hook_repo(self, frameworks=self.frameworks) def uninstall_hooks(self): - HookRouter.unhook_repo(self) + HookRouter.unhook_repo(fs=self) def install_hooks( @@ -717,11 +721,34 @@ def install_hooks( HookRouter.hook_repo(fs, frameworks) -def uninstall_hooks(): +def uninstall_hooks(fs: Optional["DagsHubFilesystem"] = None, path: Optional[Union[str, PathLike]] = None): """ Reverses the changes made by :func:`install_hooks` + You can specify a filesystem or a path to unhook just one specific filesystem + If nothing is specified, all current hooks will be cancelled + + Args: + fs: DagsHubFilesystem + """ + if fs is not None or path is not None: + HookRouter.unhook_repo(fs=fs, path=path) + else: + # Uninstall everything + HookRouter.uninstall_monkey_patch() + + +def get_mounted_filesystems() -> List[Tuple[Path, "DagsHubFilesystem"]]: + """ + Returns all currently mounted filesystems + Returns: + List of tuples of (project root path, DagsHubFilesystem object) """ - HookRouter.uninstall_monkey_patch() + return [(fs.project_root, fs) for fs in HookRouter.active_filesystems] -__all__ = [DagsHubFilesystem.__name__, install_hooks.__name__] +__all__ = [ + DagsHubFilesystem.__name__, + install_hooks.__name__, + uninstall_hooks.__name__, + get_mounted_filesystems.__name__, +] diff --git a/dagshub/streaming/hook_router.py b/dagshub/streaming/hook_router.py index 588f95b3..9cd3cd19 100644 --- a/dagshub/streaming/hook_router.py +++ b/dagshub/streaming/hook_router.py @@ -5,13 +5,16 @@ import os import sys from os import PathLike -from typing import Union, Optional, Callable, Dict, List, TYPE_CHECKING +from typing import Union, Optional, Callable, Dict, List, TYPE_CHECKING, Set from dagshub.common import is_inside_notebook, is_inside_colab from dagshub.streaming.dataclasses import DagshubPath, PathType +from dagshub.streaming.errors import
FilesystemAlreadyMountedError from dagshub.streaming.filesystem import DagshubScandirIterator from dagshub.streaming.util import wrapreturn +from pathlib import Path + if TYPE_CHECKING: from dagshub.streaming.filesystem import DagsHubFilesystem @@ -21,7 +24,6 @@ if PRE_PYTHON3_11: from pathlib import _NormalAccessor as _pathlib # noqa - logger = logging.getLogger(__name__) @@ -36,7 +38,7 @@ class HookRouter: is_monkey_patching = False - active_filesystems: List["DagsHubFilesystem"] = [] + active_filesystems: Set["DagsHubFilesystem"] = set() # Framework-specific override functions. # These functions will be patched with a function that calls fs.open() before calling the original function @@ -48,7 +50,7 @@ class HookRouter: @classmethod def open( cls, - file: Union[str, int, bytes, PathLike, DagshubPath], + file: PathType, mode="r", buffering=-1, encoding=None, @@ -301,18 +303,44 @@ def determine_fs(cls, path: Union[str, bytes, PathLike]) -> Optional[DagshubPath return min(possible_paths, key=cls._dagshub_path_relative_length) return None + @classmethod + def get_fs_by_path(cls, path: Union[str, PathLike]) -> Optional["DagsHubFilesystem"]: + fs: Optional["DagsHubFilesystem"] = None + path = Path(os.path.abspath(path)) + for active_fs in cls.active_filesystems: + if active_fs.project_root == path: + fs = active_fs + break + return fs + @classmethod def hook_repo(cls, fs: "DagsHubFilesystem", frameworks: Optional[List[str]]): if not cls.is_monkey_patching: cls.init_monkey_patch(frameworks) - cls.active_filesystems.append(fs) + + existing_fs = cls.get_fs_by_path(fs.project_root) + if existing_fs is not None: + raise FilesystemAlreadyMountedError( + existing_fs.project_root, existing_fs.repo_api.full_name, existing_fs.current_revision + ) + + cls.active_filesystems.add(fs) @classmethod def unhook_repo(cls, fs: Optional["DagsHubFilesystem"] = None, path: Optional[Union[str, PathLike]] = None): - if fs in cls.active_filesystems: - cls.active_filesystems.remove(fs) + if fs 
is None and path is None: + raise AttributeError("Only one of `fs` or `path` should be specified at the same time") + + # Find a filesystem by path if path is not None: - raise NotImplementedError("Unhooking by path is not implemented yet") + fs = cls.get_fs_by_path(path) + if fs is None: + raise RuntimeError(f"No filesystem mounted at {path}") + + # Unhook the fs + if fs is not None and fs in cls.active_filesystems: + cls.active_filesystems.remove(fs) - if len(cls.active_filesystems): + # If there are no more filesystems anymore, unhook + if len(cls.active_filesystems) == 0: cls.uninstall_monkey_patch() diff --git a/tests/dda/filesystem/test_misc.py b/tests/dda/filesystem/test_misc.py index 6ac938bc..43e08517 100644 --- a/tests/dda/filesystem/test_misc.py +++ b/tests/dda/filesystem/test_misc.py @@ -8,7 +8,6 @@ from dagshub.streaming import DagsHubFilesystem from dagshub.streaming.dataclasses import DagshubPath -from dagshub.streaming.errors import FilesystemAlreadyMountedError @pytest.mark.parametrize( @@ -33,22 +32,6 @@ def test_passthrough_path(path, expected): assert actual == expected -@pytest.mark.parametrize( - "a_path, b_path, create_folder", - [(".", ".", False), (".", "./subpath", True), (".", "../", False)], - ids=["Same dir", "Sub dir", "Parent dir"], -) -def test_cant_mount_multiples(mock_api, a_path, b_path, create_folder): - new_branch = "new" - sha = secrets.token_hex(nbytes=20) - mock_api.add_branch(new_branch, sha) - _ = DagsHubFilesystem(project_root=a_path, repo_url="https://dagshub.com/user/repo") - if create_folder: - os.makedirs(b_path, exist_ok=True) - with pytest.raises(FilesystemAlreadyMountedError): - _ = DagsHubFilesystem(project_root=b_path, repo_url="https://dagshub.com/user/repo", branch=new_branch) - - def test_can_mount_multiple_in_different_dirs(mock_api): new_branch = "new" sha = secrets.token_hex(nbytes=20) diff --git a/tests/dda/filesystem/test_multihooks.py b/tests/dda/filesystem/test_multihooks.py index 9da3d12a..f53bb644 
100644 --- a/tests/dda/filesystem/test_multihooks.py +++ b/tests/dda/filesystem/test_multihooks.py @@ -3,7 +3,8 @@ import pytest -from dagshub.streaming import DagsHubFilesystem, uninstall_hooks +from dagshub.streaming import DagsHubFilesystem, uninstall_hooks, get_mounted_filesystems +from dagshub.streaming.errors import FilesystemAlreadyMountedError from tests.mocks.repo_api import MockRepoAPI @@ -43,6 +44,21 @@ def mocked(_self: DagsHubFilesystem, _path): return mocked +@pytest.fixture +def mock_fs_1(repo_1, tmp_path) -> DagsHubFilesystem: + mock_fs = generate_mock_fs(repo_1, tmp_path / repo_1.repo_name) + yield mock_fs + # Uninstall hooks in the end to be sure that it didn't get left over + mock_fs.uninstall_hooks() + + +@pytest.fixture +def mock_fs_2(repo_2, tmp_path) -> DagsHubFilesystem: + mock_fs = generate_mock_fs(repo_2, tmp_path / repo_2.repo_name) + yield mock_fs + mock_fs.uninstall_hooks() + + def generate_mock_fs(repo_api: MockRepoAPI, file_dir: Path) -> DagsHubFilesystem: with patch("dagshub.streaming.DagsHubFilesystem._generate_repo_api", mock_repo_api_patch(repo_api)): fs = DagsHubFilesystem(project_root=file_dir, repo_url="https://localhost.invalid") @@ -105,3 +121,68 @@ def test_nesting_priority_reverse_order(repo_1, repo_2, tmp_path): assert open(path2 / "a/b.txt", "rb").read() == b"content repo 2" finally: uninstall_hooks() + + +def test_cant_hook_in_the_same_folder(repo_1, repo_2, tmp_path): + path1 = tmp_path / "mount" + path2 = tmp_path / "mount" + + fs1 = generate_mock_fs(repo_1, path1) + fs2 = generate_mock_fs(repo_2, path2) + + try: + fs1.install_hooks() + with pytest.raises(FilesystemAlreadyMountedError): + fs2.install_hooks() + + finally: + uninstall_hooks() + + +def test_initial_state_has_no_hooks(): + assert len(get_mounted_filesystems()) == 0 + + +def test_install_hooks_adds_to_list_of_active(mock_fs_1): + mock_fs_1.install_hooks() + mounted_fses = get_mounted_filesystems() + assert len(mounted_fses) == 1 + assert mounted_fses[0][1] 
== mock_fs_1 + + +def test_uninstall_hooks_removes_from_list_of_active(mock_fs_1): + mock_fs_1.install_hooks() + mock_fs_1.uninstall_hooks() + assert len(get_mounted_filesystems()) == 0 + + +def test_global_uninstall_hooks_removes_all_by_default(mock_fs_1, mock_fs_2): + mock_fs_1.install_hooks() + mock_fs_2.install_hooks() + uninstall_hooks() + assert len(get_mounted_filesystems()) == 0 + + +def test_global_uninstall_hooks_remove_by_fs(mock_fs_1, mock_fs_2): + mock_fs_1.install_hooks() + mock_fs_2.install_hooks() + uninstall_hooks(fs=mock_fs_1) + mounted_fses = get_mounted_filesystems() + assert len(mounted_fses) == 1 + assert mounted_fses[0][1] == mock_fs_2 + + +def test_global_uninstall_hooks_remove_by_path(mock_fs_1, mock_fs_2): + mock_fs_1.install_hooks() + mock_fs_2.install_hooks() + uninstall_hooks(path=mock_fs_1.project_root) + mounted_fses = get_mounted_filesystems() + assert len(mounted_fses) == 1 + assert mounted_fses[0][1] == mock_fs_2 + + +def test_cant_access_after_uninstall_hooks(mock_fs_1): + mock_fs_1.install_hooks() + mock_fs_1.uninstall_hooks() + with pytest.raises(FileNotFoundError): + open(mock_fs_1.project_root / "a/b.txt") diff --git a/tests/dda/filesystem/test_open.py b/tests/dda/filesystem/test_open.py index f123d195..1a78dd0a 100644 --- a/tests/dda/filesystem/test_open.py +++ b/tests/dda/filesystem/test_open.py @@ -5,7 +5,7 @@ def test_sets_current_revision(mock_api): fs = DagsHubFilesystem() - assert fs._current_revision == mock_api.current_revision + assert fs.current_revision == mock_api.current_revision assert mock_api["branch"].called From 351181b5fe8b4bb6fc5e78058733ae9ff6eb828d Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 14 Mar 2024 14:19:26 +0200 Subject: [PATCH 06/13] Change error help for FilesystemAlreadyMountedError --- dagshub/streaming/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagshub/streaming/errors.py b/dagshub/streaming/errors.py index bbfc773b..4699905b 100644 --- 
a/dagshub/streaming/errors.py +++ b/dagshub/streaming/errors.py @@ -11,5 +11,5 @@ def __str__(self): return ( f"There is already a filesystem mounted at path {self.path.absolute()} " f"({self.repo} revision {self.revision})" - f"\ndel() the filesystem object in use if you want to switch the mounted filesystem" + f"\nrun uninstall_hooks({self.path.absolute()}) to remove the existing hook" ) From 973a4e6f1e33a1dcc35b4865d9a9c55a3b54aa26 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 14 Mar 2024 14:34:22 +0200 Subject: [PATCH 07/13] tenacity fix --- dagshub/common/api/repo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagshub/common/api/repo.py b/dagshub/common/api/repo.py index 45b7ac16..8cd1a8e3 100644 --- a/dagshub/common/api/repo.py +++ b/dagshub/common/api/repo.py @@ -4,7 +4,7 @@ import rich.progress from httpx import Response -from tenacity import retry_if_result, stop_after_attempt, wait_exponential, before_sleep_log, retry, retry_if_exception +from tenacity import stop_after_attempt, wait_exponential, before_sleep_log, retry, retry_if_exception from dagshub.common.api.responses import ( RepoAPIResponse, @@ -237,7 +237,7 @@ def _get(): return entries @retry( - retry=retry_if_result(_is_server_error_exception), + retry=retry_if_exception(_is_server_error_exception), stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -264,7 +264,7 @@ def get_file(self, path: str, revision: Optional[str] = None) -> bytes: return res.content @retry( - retry=retry_if_result(_is_server_error_exception), + retry=retry_if_exception(_is_server_error_exception), stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), before_sleep=before_sleep_log(logger, logging.WARNING), From fb8dc04b0d7071c55d4128636260826288f6d692 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Mon, 25 Mar 2024 15:09:24 +0200 Subject: [PATCH 08/13] Add a message explaining 
what the hooks are doing to the user --- dagshub/common/helpers.py | 2 +- dagshub/streaming/hook_router.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/dagshub/common/helpers.py b/dagshub/common/helpers.py index e031ec80..e863cfa7 100644 --- a/dagshub/common/helpers.py +++ b/dagshub/common/helpers.py @@ -97,7 +97,7 @@ def prompt_user(prompt, default=False) -> bool: return prompt_response == "y" -def log_message(msg, logger=None): +def log_message(msg: str, logger=None): """ Logs message to the info of the logger + prints, unless the printing was suppressed """ diff --git a/dagshub/streaming/hook_router.py b/dagshub/streaming/hook_router.py index 9cd3cd19..af43d129 100644 --- a/dagshub/streaming/hook_router.py +++ b/dagshub/streaming/hook_router.py @@ -8,6 +8,7 @@ from typing import Union, Optional, Callable, Dict, List, TYPE_CHECKING, Set from dagshub.common import is_inside_notebook, is_inside_colab +from dagshub.common.helpers import log_message from dagshub.streaming.dataclasses import DagshubPath, PathType from dagshub.streaming.errors import FilesystemAlreadyMountedError from dagshub.streaming.filesystem import DagshubScandirIterator @@ -324,6 +325,13 @@ def hook_repo(cls, fs: "DagsHubFilesystem", frameworks: Optional[List[str]]): existing_fs.project_root, existing_fs.repo_api.full_name, existing_fs.current_revision ) + msg = ( + f'Repository "{fs.repo_api.full_name}" is now hooked at path "{fs.project_root}".\n' + f"Any calls to Python file access function like open() and listdir() inside " + f"of this directory will include results from the repository." 
+ ) + log_message(msg, logger) + cls.active_filesystems.add(fs) @classmethod From 033211b24ea585b1fa6d95a7f8668c8ca9c664be Mon Sep 17 00:00:00 2001 From: KBolashev Date: Tue, 26 Mar 2024 16:17:24 +0200 Subject: [PATCH 09/13] Explain the ignored rules --- ruff.toml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index e414d1ae..74699aeb 100644 --- a/ruff.toml +++ b/ruff.toml @@ -2,4 +2,11 @@ line-length=120 [lint] select = ["E", "F"] -ignore = ["E111", "E203", "E114", "E117", "E701"] +# Ignore some of the rules that are conflicting with the formatter +ignore = [ + "E111", # Indentation is not a multiple of four + "E203", # Whitespace before ':' + "E114", # Indentation is not a multiple of four (comment) + "E117", # Over-indented + "E701", # Multiple statements on one line (colon) +] From 56e089b657704ed80302864b825b54cdd3e53763 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Tue, 26 Mar 2024 16:51:02 +0200 Subject: [PATCH 10/13] Fully qualify the uninstall_hooks function in the error --- dagshub/streaming/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagshub/streaming/errors.py b/dagshub/streaming/errors.py index 4699905b..92063a41 100644 --- a/dagshub/streaming/errors.py +++ b/dagshub/streaming/errors.py @@ -11,5 +11,5 @@ def __str__(self): return ( f"There is already a filesystem mounted at path {self.path.absolute()} " f"({self.repo} revision {self.revision})" - f"\nrun uninstall_hooks({self.path.absolute()}) to remove the existing hook" + f"\nrun dagshub.streaming.uninstall_hooks({self.path.absolute()}) to remove the existing hook" ) From ae90a4347194a00d38f1c56292c7f408cb791ac3 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Mon, 1 Apr 2024 11:20:14 +0300 Subject: [PATCH 11/13] Ignore 6xx+ status codes in is_server_error_exception --- dagshub/common/api/repo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagshub/common/api/repo.py b/dagshub/common/api/repo.py index 
8cd1a8e3..ddb3018c 100644 --- a/dagshub/common/api/repo.py +++ b/dagshub/common/api/repo.py @@ -69,7 +69,7 @@ def __str__(self): def _is_server_error_exception(exception: BaseException) -> bool: if not isinstance(exception, DagsHubHTTPError): return False - return exception.response.status_code >= 500 + return 500 <= exception.response.status_code < 600 class RepoAPI: From df087b25f8a551d0a6e8a00ba99872d5ef883eaf Mon Sep 17 00:00:00 2001 From: KBolashev Date: Mon, 1 Apr 2024 12:05:25 +0300 Subject: [PATCH 12/13] Retry: reuse between functions, add jitter, make the max wait 5 minutes --- dagshub/common/api/repo.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/dagshub/common/api/repo.py b/dagshub/common/api/repo.py index ddb3018c..4c0a32a6 100644 --- a/dagshub/common/api/repo.py +++ b/dagshub/common/api/repo.py @@ -1,10 +1,11 @@ import logging +from functools import partial from os import PathLike from pathlib import Path, PurePosixPath import rich.progress from httpx import Response -from tenacity import stop_after_attempt, wait_exponential, before_sleep_log, retry, retry_if_exception +from tenacity import before_sleep_log, retry, retry_if_exception, wait_exponential_jitter, stop_after_delay from dagshub.common.api.responses import ( RepoAPIResponse, @@ -17,7 +18,6 @@ from dagshub.common.download import download_files from dagshub.common.rich_util import get_rich_progress from dagshub.common.util import multi_urljoin -from functools import partial try: from functools import cached_property @@ -72,6 +72,15 @@ def _is_server_error_exception(exception: BaseException) -> bool: return 500 <= exception.response.status_code < 600 +def request_retry(func): + return retry( + retry=retry_if_exception(_is_server_error_exception), + stop=stop_after_delay(5 * 60), + wait=wait_exponential_jitter(max=60, jitter=2), + before_sleep=before_sleep_log(logger, logging.WARNING), + )(func) + + class RepoAPI: def __init__(self, repo: str, 
host: Optional[str] = None, auth: Optional[Any] = None): """ @@ -236,12 +245,7 @@ def _get(): return entries - @retry( - retry=retry_if_exception(_is_server_error_exception), - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) + @request_retry def get_file(self, path: str, revision: Optional[str] = None) -> bytes: """ Download file from repo. @@ -263,12 +267,7 @@ def get_file(self, path: str, revision: Optional[str] = None) -> bytes: raise DagsHubHTTPError(error_msg, res) return res.content - @retry( - retry=retry_if_exception(_is_server_error_exception), - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) + @request_retry def get_storage_file(self, path: str) -> bytes: """ Download file from a connected storage bucket. From 3d63f508c7274ba5104219c4607e10017d4ebab2 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Mon, 1 Apr 2024 12:11:46 +0300 Subject: [PATCH 13/13] Carry over original path in DagshubPath as-is without turning it into Path --- dagshub/streaming/dataclasses.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dagshub/streaming/dataclasses.py b/dagshub/streaming/dataclasses.py index b97a6dc5..b80b3890 100644 --- a/dagshub/streaming/dataclasses.py +++ b/dagshub/streaming/dataclasses.py @@ -35,16 +35,18 @@ def __init__(self, fs: "DagsHubFilesystem", file_path: Union[str, bytes, PathLik self.is_binary_path_requested = isinstance(file_path, bytes) self.absolute_path, self.relative_path, self.original_path = self.parse_path(file_path) - def parse_path(self, file_path: Union[str, bytes, PathLike, "DagshubPath"]) -> Tuple[Path, Optional[Path], Path]: + def parse_path( + self, file_path: Union[str, bytes, PathLike, "DagshubPath"] + ) -> Tuple[Path, Optional[Path], Union[str, bytes, PathLike]]: if isinstance(file_path, DagshubPath): 
self.is_binary_path_requested = file_path.is_binary_path_requested if file_path.fs != self.fs: relativized = DagshubPath(self.fs, file_path.absolute_path) return relativized.absolute_path, relativized.relative_path, relativized.original_path return file_path.absolute_path, file_path.relative_path, file_path.original_path + orig_path = file_path if isinstance(file_path, bytes): file_path = os.fsdecode(file_path) - orig_path = Path(file_path) abspath = Path(os.path.abspath(file_path)) try: relpath = abspath.relative_to(os.path.abspath(self.fs.project_root)) @@ -105,7 +107,7 @@ def is_passthrough_path(self, fs: "DagsHubFilesystem"): def __truediv__(self, other): new = DagshubPath( self.fs, - self.original_path / other, + Path(self.original_path) / other, ) new.is_binary_path_requested = self.is_binary_path_requested return new