Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions dagshub/common/api/repo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
from functools import partial
from os import PathLike
from pathlib import Path, PurePosixPath

import rich.progress
from httpx import Response
from tenacity import before_sleep_log, retry, retry_if_exception, wait_exponential_jitter, stop_after_delay

from dagshub.common.api.responses import (
RepoAPIResponse,
Expand All @@ -15,7 +18,6 @@
from dagshub.common.download import download_files
from dagshub.common.rich_util import get_rich_progress
from dagshub.common.util import multi_urljoin
from functools import partial

try:
from functools import cached_property
Expand Down Expand Up @@ -54,6 +56,31 @@ class PathNotFoundError(Exception):
pass


class DagsHubHTTPError(Exception):
    """
    Error raised when a DagsHub API request comes back with a failing HTTP status code.

    Attributes:
        msg: Human-readable description of the failure; this is what ``str(error)`` returns.
        response: The HTTP response that triggered the error. Its ``status_code`` is what
            the retry predicate inspects to tell server errors apart from client errors.
    """

    def __init__(self, msg: str, response: "Response"):
        # Forward the arguments to Exception so that ``args`` is populated.
        # With an empty ``args`` the exception cannot be copied or pickled (reconstruction
        # calls the constructor with ``*args``), and ``repr()`` loses the message.
        super().__init__(msg, response)
        self.msg = msg
        self.response = response

    def __str__(self):
        # Only the message: the response object is kept out of the printed error.
        return self.msg


def _is_server_error_exception(exception: BaseException) -> bool:
    """
    Tell whether ``exception`` is a ``DagsHubHTTPError`` carrying a 5xx response.

    Args:
        exception: The exception to inspect.

    Returns:
        bool: True only for a ``DagsHubHTTPError`` whose response status code is in the
        range 500-599; False for any other exception or status code.
    """
    return isinstance(exception, DagsHubHTTPError) and 500 <= exception.response.status_code < 600


def request_retry(func):
    """
    Decorator that retries ``func`` while it keeps failing with a server-side (5xx) HTTP error.

    Only exceptions accepted by ``_is_server_error_exception`` trigger a retry; any other
    exception propagates immediately. Waits between attempts grow exponentially with jitter
    (capped at 60 seconds), and retrying stops once 5 minutes have elapsed. Each retry is
    logged at WARNING level before sleeping.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable.

    Raises:
        Whatever ``func`` raised on its last attempt. When the retry budget is exhausted the
        original exception is re-raised, not wrapped in ``tenacity.RetryError``.
    """
    return retry(
        retry=retry_if_exception(_is_server_error_exception),
        stop=stop_after_delay(5 * 60),
        wait=wait_exponential_jitter(max=60, jitter=2),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        # Without this, tenacity raises RetryError on exhaustion, hiding the underlying
        # HTTP error (message and response) from callers.
        reraise=True,
    )(func)


class RepoAPI:
def __init__(self, repo: str, host: Optional[str] = None, auth: Optional[Any] = None):
"""
Expand Down Expand Up @@ -89,7 +116,7 @@ def get_repo_info(self) -> RepoAPIResponse:
error_msg = f"Got status code {res.status_code} when getting repository info."
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)
return dacite.from_dict(RepoAPIResponse, res.json())

def get_branch_info(self, branch: str) -> BranchAPIResponse:
Expand All @@ -107,7 +134,7 @@ def get_branch_info(self, branch: str) -> BranchAPIResponse:
error_msg = f"Got status code {res.status_code} when getting branch."
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)

return dacite.from_dict(BranchAPIResponse, res.json())

Expand All @@ -126,7 +153,7 @@ def get_commit_info(self, sha: str) -> CommitAPIResponse:
error_msg = f"Got status code {res.status_code} when getting commit."
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)

return dacite.from_dict(CommitAPIResponse, res.json()["commit"])

Expand All @@ -142,7 +169,7 @@ def get_connected_storages(self) -> List[StorageAPIEntry]:
error_msg = f"Got status code {res.status_code} when getting repository info."
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)

return [dacite.from_dict(StorageAPIEntry, storage_entry) for storage_entry in res.json()]

Expand All @@ -164,7 +191,7 @@ def list_path(self, path: str, revision: Optional[str] = None, include_size: boo
error_msg = f"Got status code {res.status_code} when listing path {path}"
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)

content = res.json()
if type(content) is dict:
Expand Down Expand Up @@ -194,7 +221,7 @@ def _get():
error_msg = f"Got status code {res.status_code} when listing path {path}"
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)

content = res.json()
if "entries" not in content:
Expand All @@ -218,6 +245,7 @@ def _get():

return entries

@request_retry
def get_file(self, path: str, revision: Optional[str] = None) -> bytes:
"""
Download file from repo.
Expand All @@ -229,16 +257,17 @@ def get_file(self, path: str, revision: Optional[str] = None) -> bytes:
Returns:
bytes: The content of the file.
"""
res = self._http_request("GET", self.raw_api_url(path, revision))
res = self._http_request("GET", self.raw_api_url(path, revision), timeout=None)
if res.status_code == 404:
raise PathNotFoundError(f"Path {path} not found")
elif res.status_code >= 400:
error_msg = f"Got status code {res.status_code} when getting file {path}"
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)
return res.content

@request_retry
def get_storage_file(self, path: str) -> bytes:
"""
Download file from a connected storage bucket.
Expand All @@ -258,7 +287,7 @@ def get_storage_file(self, path: str) -> bytes:
error_msg = f"Got status code {res.status_code} when getting file {path}"
logger.error(error_msg)
logger.debug(res.content)
raise RuntimeError(error_msg)
raise DagsHubHTTPError(error_msg, res)
return res.content

def _get_files_in_path(
Expand Down
2 changes: 1 addition & 1 deletion dagshub/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def prompt_user(prompt, default=False) -> bool:
return prompt_response == "y"


def log_message(msg, logger=None):
def log_message(msg: str, logger=None):
"""
Logs message to the info of the logger + prints, unless the printing was suppressed
"""
Expand Down
10 changes: 8 additions & 2 deletions dagshub/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .filesystem import DagsHubFilesystem, install_hooks, uninstall_hooks
from .filesystem import DagsHubFilesystem, install_hooks, uninstall_hooks, get_mounted_filesystems

try:
from .mount import mount
Expand All @@ -15,4 +15,10 @@ def mount(*args, **kwargs):
print(error)


__all__ = [DagsHubFilesystem.__name__, install_hooks.__name__, mount.__name__, uninstall_hooks.__name__]
# Names exported by this package. Each entry is taken from the object's ``__name__``
# instead of being written as a string literal.
# NOTE(review): ``mount`` appears to be either the import from ``.mount`` or a fallback
# stub defined when that import fails — confirm both are named "mount".
__all__ = [
DagsHubFilesystem.__name__,
install_hooks.__name__,
mount.__name__,
uninstall_hooks.__name__,
get_mounted_filesystems.__name__,
]
Loading