Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion bilix/download/base_downloader_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

# Names exported by `from <this module> import *`.
__all__ = ['BaseDownloaderPart']

# Module-level failure tally keyed by URL host (the `netloc` from `urlparse`).
# `BaseDownloaderPart._get_file_part` declares it `global`, so the same dict is
# shared by every call in this process. That method:
#   * picks, on each attempt, the candidate URL whose host has the lowest score
#     (hosts not in the dict yet count as 0);
#   * subtracts that lowest score from every entry already stored here;
#   * adds 1 to a host's score when a ranged GET to it raises
#     httpx.HTTPStatusError or httpx.TransportError.
# NOTE(review): nothing visible in this diff removes entries or guards the dict
# against concurrent coroutines -- confirm whether that matters for callers.
domain_scores = {}


class BaseDownloaderPart(BaseDownloader):
"""Base Async http Content-Range Downloader"""
Expand Down Expand Up @@ -199,6 +201,7 @@ async def get_file(self, url_or_urls: Union[str, Iterable[str]], path: Union[Pat

async def _get_file_part(self, urls: List[str], path: Path, part_range: Tuple[int, int],
task_id) -> Path:
global domain_scores
start, end = part_range
part_path = path.with_name(f'{path.name}.{part_range[0]}-{part_range[1]}')
exist, part_path = path_check(part_path)
Expand All @@ -208,10 +211,20 @@ async def _get_file_part(self, urls: List[str], path: Path, part_range: Tuple[in
await self.progress.update(task_id, advance=downloaded)
if start > end:
return part_path # skip already finished
url_idx = random.randint(0, len(urls) - 1)

for times in range(1 + self.stream_retry):
# find domain with min score in domain_scores
domains = [urlparse(url_).netloc for url_ in urls]
domain_idx = min(range(len(domains)), key=lambda i: domain_scores.get(domains[i], 0))
url_idx = domain_idx
# update domain_scores
domain_min_score = domain_scores.get(domains[domain_idx], 0)
for domain in domain_scores.keys():
domain_scores[domain] = domain_scores[domain] - domain_min_score

try:
# parse domain from url
domain = urlparse(urls[url_idx]).netloc
async with \
self.client.stream("GET", urls[url_idx], follow_redirects=True,
headers={'Range': f'bytes={start}-{end}'}) as r, \
Expand All @@ -227,6 +240,7 @@ async def _get_file_part(self, urls: List[str], path: Path, part_range: Tuple[in
await self._check_speed(len(chunk))
break
except (httpx.HTTPStatusError, httpx.TransportError):
domain_scores[domain] = domain_scores.get(domain, 0) + 1
continue
else:
raise Exception(f"STREAM 超过重复次数 {part_path.name}")
Expand Down