Skip to content

Commit 88878d9

Browse files
committed
feat: add alternative download methods to resolver API
Extend the resolver API with alternative download URLs. Resolvers can now return download links to alternative locations or retrieval methods. The `PyPIProvider` now accepts an `override_download_url` parameter. The value overwrites the default PyPI download link. The string can contain a `{version}` format variable. The GitHub and GitLab tag providers can return git clone URLs for `https` and `ssh` transport. The URLs use pip's VCS syntax like `git+https://host/repo.git@tag`. The new enum `RetrieveMethod` has a `from_url()` constructor that parses a URL and splits it into method, URL, and git ref. Signed-off-by: Christian Heimes <cheimes@redhat.com>
1 parent 3a9bfec commit 88878d9

File tree

2 files changed

+226
-8
lines changed

2 files changed

+226
-8
lines changed

src/fromager/resolver.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import datetime
9+
import enum
910
import functools
1011
import logging
1112
import os
@@ -14,7 +15,7 @@
1415
from collections.abc import Iterable
1516
from operator import attrgetter
1617
from platform import python_version
17-
from urllib.parse import quote, unquote, urljoin, urlparse
18+
from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit
1819

1920
import pypi_simple
2021
import resolvelib
@@ -180,11 +181,42 @@ def resolve_from_provider(
180181
raise ValueError(f"Unable to resolve {req}")
181182

182183

184+
class RetrieveMethod(enum.StrEnum):
185+
tarball = "tarball"
186+
git_https = "git+https"
187+
git_ssh = "git+ssh"
188+
189+
@classmethod
190+
def from_url(cls, download_url) -> tuple[RetrieveMethod, str, str | None]:
191+
"""Parse a download URL into method, url, reference"""
192+
scheme, netloc, path, query, fragment = urlsplit(
193+
download_url, allow_fragments=False
194+
)
195+
match scheme:
196+
case "https":
197+
return RetrieveMethod.tarball, download_url, None
198+
case "git+https":
199+
method = RetrieveMethod.git_https
200+
case "git+ssh":
201+
method = RetrieveMethod.git_ssh
202+
case _:
203+
raise ValueError(f"unsupported download URL {download_url!r}")
204+
# remove git+
205+
scheme = scheme[4:]
206+
# split off @ revision
207+
if "@" not in path:
208+
raise ValueError(f"git download url {download_url!r} is missing '@ref'")
209+
path, ref = path.rsplit("@", 1)
210+
return method, urlunsplit((scheme, netloc, path, query, fragment)), ref
211+
212+
183213
def get_project_from_pypi(
184214
project: str,
185215
extras: typing.Iterable[str],
186216
sdist_server_url: str,
187217
ignore_platform: bool = False,
218+
*,
219+
override_download_url: str | None = None,
188220
) -> Candidates:
189221
"""Return candidates created from the project name and extras."""
190222
found_candidates: set[str] = set()
@@ -345,14 +377,19 @@ def get_project_from_pypi(
345377
ignored_candidates.add(dp.filename)
346378
continue
347379

380+
if override_download_url is None:
381+
url = dp.url
382+
else:
383+
url = override_download_url.format(version=version)
384+
348385
upload_time = dp.upload_time
349386
if upload_time is not None:
350387
upload_time = upload_time.astimezone(datetime.UTC)
351388

352389
c = Candidate(
353390
name=name,
354391
version=version,
355-
url=dp.url,
392+
url=url,
356393
extras=tuple(sorted(extras)),
357394
is_sdist=is_sdist,
358395
build_tag=build_tag,
@@ -603,6 +640,7 @@ def __init__(
603640
ignore_platform: bool = False,
604641
*,
605642
use_resolver_cache: bool = True,
643+
override_download_url: str | None = None,
606644
):
607645
super().__init__(
608646
constraints=constraints,
@@ -613,6 +651,7 @@ def __init__(
613651
self.include_wheels = include_wheels
614652
self.sdist_server_url = sdist_server_url
615653
self.ignore_platform = ignore_platform
654+
self.override_download_url = override_download_url
616655

617656
@property
618657
def cache_key(self) -> str:
@@ -625,9 +664,10 @@ def cache_key(self) -> str:
625664
def find_candidates(self, identifier: str) -> Candidates:
626665
return get_project_from_pypi(
627666
identifier,
628-
set(),
629-
self.sdist_server_url,
630-
self.ignore_platform,
667+
extras=set(),
668+
sdist_server_url=self.sdist_server_url,
669+
ignore_platform=self.ignore_platform,
670+
override_download_url=self.override_download_url,
631671
)
632672

633673
def validate_candidate(
@@ -803,6 +843,7 @@ def __init__(
803843
*,
804844
req_type: RequirementType | None = None,
805845
use_resolver_cache: bool = True,
846+
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
806847
):
807848
super().__init__(
808849
constraints=constraints,
@@ -813,6 +854,7 @@ def __init__(
813854
)
814855
self.organization = organization
815856
self.repo = repo
857+
self.retrieve_method = retrieve_method
816858

817859
@property
818860
def cache_key(self) -> str:
@@ -847,7 +889,14 @@ def _find_tags(
847889
logger.debug(f"{identifier}: match function ignores {tagname}")
848890
continue
849891
assert isinstance(version, Version)
850-
url = entry["tarball_url"]
892+
893+
match self.retrieve_method:
894+
case RetrieveMethod.tarball:
895+
url = entry["tarball_url"]
896+
case RetrieveMethod.git_https:
897+
url = f"git+https://{self.host}/{self.organization}/{self.repo}.git@{tagname}"
898+
case RetrieveMethod.git_ssh:
899+
url = f"git+ssh://git@{self.host}/{self.organization}/{self.repo}.git@{tagname}"
851900

852901
# Github tag API endpoint does not include commit date information.
853902
# It would be too expensive to query every commit API endpoint.
@@ -880,6 +929,7 @@ def __init__(
880929
*,
881930
req_type: RequirementType | None = None,
882931
use_resolver_cache: bool = True,
932+
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
883933
) -> None:
884934
super().__init__(
885935
constraints=constraints,
@@ -889,6 +939,9 @@ def __init__(
889939
matcher=matcher,
890940
)
891941
self.server_url = server_url.rstrip("/")
942+
self.server_hostname = urlparse(server_url).hostname
943+
if not self.server_hostname:
944+
raise ValueError(f"invalid {server_url=}")
892945
self.project_path = project_path.lstrip("/")
893946
# URL-encode the project path as required by GitLab API.
894947
# The safe="" parameter tells quote() to encode ALL characters,
@@ -899,6 +952,7 @@ def __init__(
899952
self.api_url = (
900953
f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags"
901954
)
955+
self.retrieve_method = retrieve_method
902956

903957
@property
904958
def cache_key(self) -> str:
@@ -927,8 +981,14 @@ def _find_tags(
927981
continue
928982
assert isinstance(version, Version)
929983

930-
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
931-
url = urljoin(self.server_url, archive_path)
984+
match self.retrieve_method:
985+
case RetrieveMethod.tarball:
986+
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
987+
url = urljoin(self.server_url, archive_path)
988+
case RetrieveMethod.git_https:
989+
url = f"git+https://{self.server_hostname}/{self.project_path}.git@{tagname}"
990+
case RetrieveMethod.git_ssh:
991+
url = f"git+ssh://git@{self.server_hostname}/{self.project_path}.git@{tagname}"
932992

933993
# get tag creation time, fall back to commit creation time
934994
created_at_str: str | None = entry.get("created_at")

tests/test_resolver.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,26 @@ def test_provider_constraint_match() -> None:
370370
assert str(candidate.version) == "1.2.2"
371371

372372

373+
def test_provider_override_download_url() -> None:
    """The override template, formatted with the version, becomes the candidate URL."""
    with requests_mock.Mocker() as mocker:
        mocker.get(
            "https://pypi.org/simple/hydra-core/",
            text=_hydra_core_simple_response,
        )

        pypi_provider = resolver.PyPIProvider(
            override_download_url="https://server.test/hydr_core-{version}.tar.gz"
        )
        resolution = resolvelib.Resolver(
            pypi_provider, resolvelib.BaseReporter()
        ).resolve([Requirement("hydra-core")])
        assert "hydra-core" in resolution.mapping

        resolved = resolution.mapping["hydra-core"]
        assert resolved.url == "https://server.test/hydr_core-1.3.2.tar.gz"
391+
392+
373393
_ignore_platform_simple_response = """
374394
<!DOCTYPE html>
375395
<html>
@@ -715,6 +735,51 @@ def test_resolve_github() -> None:
715735
)
716736

717737

738+
@pytest.mark.parametrize(
    ("retrieve_method", "expected_url"),
    [
        (
            resolver.RetrieveMethod.tarball,
            "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
        ),
        (
            resolver.RetrieveMethod.git_https,
            "git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
        ),
        (
            resolver.RetrieveMethod.git_ssh,
            "git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
        ),
    ],
)
def test_resolve_github_retrieve_method(
    retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
    """The GitHub tag provider emits a candidate URL matching the retrieve method."""
    api_base = "https://api.github.com:443/repos/python-wheel-build/fromager"
    with requests_mock.Mocker() as mocker:
        mocker.get(api_base, text=_github_fromager_repo_response)
        mocker.get(f"{api_base}/tags", text=_github_fromager_tag_response)

        github_provider = resolver.GitHubTagProvider(
            organization="python-wheel-build",
            repo="fromager",
            retrieve_method=retrieve_method,
        )
        resolution = resolvelib.Resolver(
            github_provider, resolvelib.BaseReporter()
        ).resolve([Requirement("fromager")])
        assert "fromager" in resolution.mapping

        resolved = resolution.mapping["fromager"]
        assert resolved.url == expected_url
781+
782+
718783
def test_github_constraint_mismatch() -> None:
719784
constraint = constraints.Constraints()
720785
constraint.add_constraint("fromager>=1.0")
@@ -922,6 +987,49 @@ def test_resolve_gitlab() -> None:
922987
)
923988

924989

990+
@pytest.mark.parametrize(
    ("retrieve_method", "expected_url"),
    [
        (
            resolver.RetrieveMethod.tarball,
            "https://gitlab.com/mirrors/github/decile-team/submodlib/-/archive/v0.0.3/submodlib-v0.0.3.tar.gz",
        ),
        (
            resolver.RetrieveMethod.git_https,
            "git+https://gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
        ),
        (
            resolver.RetrieveMethod.git_ssh,
            "git+ssh://git@gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
        ),
    ],
)
def test_resolve_gitlab_retrieve_method(
    retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
    """The GitLab tag provider emits a candidate URL matching the retrieve method."""
    with requests_mock.Mocker() as mocker:
        mocker.get(
            "https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags",
            text=_gitlab_submodlib_repo_response,
        )

        gitlab_provider = resolver.GitLabTagProvider(
            project_path="mirrors/github/decile-team/submodlib",
            server_url="https://gitlab.com",
            retrieve_method=retrieve_method,
        )
        resolution = resolvelib.Resolver(
            gitlab_provider, resolvelib.BaseReporter()
        ).resolve([Requirement("submodlib")])
        assert "submodlib" in resolution.mapping

        resolved = resolution.mapping["submodlib"]
        assert str(resolved.version) == "0.0.3"

        assert resolved.url == expected_url
1031+
1032+
9251033
def test_gitlab_constraint_mismatch() -> None:
9261034
constraint = constraints.Constraints()
9271035
constraint.add_constraint("submodlib>=1.0")
@@ -1107,3 +1215,53 @@ def custom_resolver_provider(*args, **kwargs):
11071215
assert "pypi.org" not in error_message.lower(), (
11081216
f"Error message incorrectly mentions PyPI when using GitHub resolver: {error_message}"
11091217
)
1218+
1219+
1220+
@pytest.mark.parametrize(
    ("download_url", "expected_method", "expected_url", "expected_ref"),
    [
        (
            "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
            resolver.RetrieveMethod.tarball,
            "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
            None,
        ),
        (
            "git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
            resolver.RetrieveMethod.git_https,
            "https://github.com:443/python-wheel-build/fromager.git",
            "0.9.0",
        ),
        (
            "git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
            resolver.RetrieveMethod.git_ssh,
            "ssh://git@github.com:443/python-wheel-build/fromager.git",
            "0.9.0",
        ),
    ],
)
def test_retrieve_method_from_url(
    download_url: str,
    expected_method: resolver.RetrieveMethod,
    expected_url: str,
    expected_ref: str | None,
) -> None:
    """``from_url`` splits a download URL into method, plain URL, and git ref."""
    parsed = resolver.RetrieveMethod.from_url(download_url)
    assert parsed == (expected_method, expected_url, expected_ref)
1254+
1255+
1256+
@pytest.mark.parametrize(
    "download_url",
    [
        # schemes other than https, git+https, and git+ssh are rejected
        "http://insecure.test",
        "hg+ssh://mercurial.test",
        # git URLs must carry an '@ref' suffix in the path; the "git@" user
        # part of an ssh URL does not count because it is in the netloc
        "git+https://host.test/repo.git",
        "git+ssh://git@host.test/repo.git",
    ],
)
def test_retrieve_method_from_url_error(
    download_url: str,
) -> None:
    """Unsupported schemes and git URLs without a ref raise ``ValueError``."""
    with pytest.raises(ValueError):
        resolver.RetrieveMethod.from_url(download_url)

0 commit comments

Comments
 (0)