def dump_url_to_file(
    url: str, filepath: str, progress: bool = True, desc: Optional[str] = None
) -> str:
    """
    Download the content at a URL into a local file. If the file exists, it
    is overwritten.

    Args:
        url (str): the URL to fetch; passed unchanged to
            ``urllib.request.urlretrieve``.
        filepath (str): the path to save the file.
            The directory is assumed to already exist.
        progress (bool): whether to use tqdm to draw a progress bar.
            tqdm is only imported when this is True.
        desc (str or None): description to pass to tqdm if drawing a
            progress bar.

    Returns:
        filepath (str): the path to the downloaded file, as returned by
            ``urlretrieve``. This is always identical to the filepath argument.
    """
    if not progress:
        # Plain download: no tqdm import, no report hook.
        saved_path, _ = request.urlretrieve(url, filename=filepath)
        return saved_path

    # Imported lazily so that tqdm is only required when a bar is requested.
    import tqdm

    def hook(t: tqdm.tqdm) -> Callable[[int, int, Optional[int]], None]:
        # urlretrieve reports the cumulative block count, while tqdm.update()
        # expects an increment, so remember the previous count between calls.
        # A one-element list is used so the closure can mutate it.
        last_b: List[int] = [0]

        def inner(b: int, bsize: int, tsize: Optional[int] = None) -> None:
            if tsize is not None:
                t.total = tsize
            t.update((b - last_b[0]) * bsize)  # type: ignore
            last_b[0] = b

        return inner

    with tqdm.tqdm(  # type: ignore
        unit="B", unit_scale=True, miniters=1, desc=desc, leave=True
    ) as t:
        saved_path, _ = request.urlretrieve(
            url, filename=filepath, reporthook=hook(t)
        )
    return saved_path
try: logger.info("Downloading from {} ...".format(url)) - if progress: - import tqdm - - def hook(t: tqdm.tqdm) -> Callable[[int, int, Optional[int]], None]: - last_b: List[int] = [0] - - def inner(b: int, bsize: int, tsize: Optional[int] = None) -> None: - if tsize is not None: - t.total = tsize - t.update((b - last_b[0]) * bsize) # type: ignore - last_b[0] = b - - return inner - - with tqdm.tqdm( # type: ignore - unit="B", unit_scale=True, miniters=1, desc=filename, leave=True - ) as t: - tmp, _ = request.urlretrieve(url, filename=tmp, reporthook=hook(t)) - - else: - tmp, _ = request.urlretrieve(url, filename=tmp) + tmp = dump_url_to_file(url, filepath=tmp, progress=progress, desc=filename) statinfo = os.stat(tmp) size = statinfo.st_size if size == 0: