From f517c21f9bf98d0369d621c575d4550af5c2d2c3 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Jul 2025 09:54:41 +0100 Subject: [PATCH 1/4] Support local SSH configuration Support local proxy environment settings and trust stores when connecting to GitHub. --- git_sync/github.py | 80 +++++++++++++++++++++++++++------------------- requirements.txt | 1 + 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/git_sync/github.py b/git_sync/github.py index 5f1cf78..d98ebf2 100644 --- a/git_sync/github.py +++ b/git_sync/github.py @@ -1,9 +1,12 @@ import re +import ssl from asyncio import Semaphore, gather from collections.abc import AsyncIterator, Callable, Iterable from dataclasses import dataclass from typing import TypeVar +import aiohttp +import truststore from aiographql.client import GraphQLClient # type: ignore[import-untyped] T = TypeVar("T") @@ -101,6 +104,12 @@ def join_queries(queries: Iterable[str]) -> str: return "{" + "\n".join(f"q{i}: {query}" for i, query in enumerate(queries)) + "}" +def client_session() -> aiohttp.ClientSession: + """Configure aiohttp to trust local SSL credentials and environment variables.""" + connector = aiohttp.TCPConnector(ssl=truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)) + return aiohttp.ClientSession(trust_env=True, connector=connector) + + async def fetch_pull_requests_from_domain( token: str, domain: str, repos: list[Repository] ) -> AsyncIterator[PullRequest]: @@ -109,42 +118,47 @@ async def fetch_pull_requests_from_domain( if domain.count(".") == 1 else f"https://{domain}/api/graphql" ) - client = GraphQLClient( - endpoint=endpoint, headers={"Authorization": f"Bearer {token}"} - ) - # Query for PRs and commit counts - initial_queries = [ - pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1) - ] - initial_response = await client.query(join_queries(initial_queries)) - assert not initial_response.errors - - # Determine what follow-up queries to make - details_queries = [ - pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"]) - for repo_data in initial_response.data.values() - for pr_data in repo_data["pullRequests"]["nodes"] - ] - - # Query for detailed PR information - details_response = await client.query(join_queries(details_queries)) - assert not details_response.errors - - # Yield response data as PullRequest objects - for pr_data in details_response.data.values(): - head_repo = pr_data.get("headRepository") or {} - repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")] - hashes = tuple( - commit["commit"]["oid"] for commit in reversed(pr_data["commits"]["nodes"]) - ) - yield PullRequest( - branch_name=pr_data["headRefName"], - repo_urls=frozenset(url for url in repo_urls if url is not None), - hashes=hashes, - merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"), + async with client_session() as session: + client = GraphQLClient( + endpoint=endpoint, + headers={"Authorization": f"Bearer {token}"}, + session=session, ) + # Query for PRs and commit counts + initial_queries = [ + pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1) + ] + initial_response = await client.query(join_queries(initial_queries)) + assert not initial_response.errors + + # Determine what follow-up queries to make + details_queries = [ + pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"]) + for repo_data in initial_response.data.values() + for pr_data in repo_data["pullRequests"]["nodes"] + ] + + # Query for detailed PR information + details_response = await client.query(join_queries(details_queries)) + assert not details_response.errors + + # Yield response data as PullRequest objects + for pr_data in details_response.data.values(): + head_repo = pr_data.get("headRepository") or {} + repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")] + hashes = tuple( + commit["commit"]["oid"] + for commit in reversed(pr_data["commits"]["nodes"]) + ) + yield PullRequest( + branch_name=pr_data["headRefName"], + repo_urls=frozenset(url for url in repo_urls if url is not None), + hashes=hashes, + merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"), + ) + async def fetch_pull_requests( tokens: Callable[[str], str | None], diff --git a/requirements.txt b/requirements.txt index d0268a9..4d3a478 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ aiographql-client >= 1.0.3 +truststore >= 0.10.1 From 12b415c20a47ff2ffca1cc4b686c38a1692ae71c Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Jul 2025 12:31:48 +0100 Subject: [PATCH 2/4] Make pyright happy --- git_sync/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/git_sync/__init__.py b/git_sync/__init__.py index ba7598d..031acf2 100644 --- a/git_sync/__init__.py +++ b/git_sync/__init__.py @@ -97,8 +97,11 @@ async def git_sync() -> None: print("Error: Not in a git repository", file=sys.stderr) sys.exit(2) - if remote_urls: - pull_request_task = create_task(fetch_pull_requests(github_token, remote_urls)) + pull_request_task = ( + create_task(fetch_pull_requests(github_token, remote_urls)) + if remote_urls + else None + ) try: branches = await get_branches_with_remote_upstreams() @@ -111,7 +114,7 @@ async def git_sync() -> None: if push_remote: await fast_forward_to_downstream(push_remote, branches) - if remote_urls: + if pull_request_task is not None: pull_requests = await pull_request_task push_remote_url = next( remote.url for remote in remotes if remote.name == push_remote From ee67eefc45b934c795a1aece183bf74ffecc34c1 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Jul 2025 12:47:03 +0100 Subject: [PATCH 3/4] Support .git suffix on https URLs The GitHub web UI appends `.git` to the end of its https URLs; the GraphQL API does not. Git seems to work fine with either, so support both. --- git_sync/github.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/git_sync/github.py b/git_sync/github.py index d98ebf2..0045a07 100644 --- a/git_sync/github.py +++ b/git_sync/github.py @@ -1,9 +1,9 @@ import re import ssl from asyncio import Semaphore, gather -from collections.abc import AsyncIterator, Callable, Iterable +from collections.abc import AsyncIterator, Callable, Iterable, Iterator from dataclasses import dataclass -from typing import TypeVar +from typing import Any, TypeVar import aiohttp import truststore @@ -110,6 +110,15 @@ def client_session() -> aiohttp.ClientSession: return aiohttp.ClientSession(trust_env=True, connector=connector) +def repo_urls(pr_data: dict[str, Any]) -> Iterator[str]: + head_repo = pr_data.get("headRepository") or {} + if ssh_url := head_repo.get("sshUrl"): + yield ssh_url + if http_url := head_repo.get("url"): + yield http_url + yield http_url + ".git" + + async def fetch_pull_requests_from_domain( token: str, domain: str, repos: list[Repository] ) -> AsyncIterator[PullRequest]: @@ -146,15 +155,13 @@ async def fetch_pull_requests_from_domain( # Yield response data as PullRequest objects for pr_data in details_response.data.values(): - head_repo = pr_data.get("headRepository") or {} - repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")] hashes = tuple( commit["commit"]["oid"] for commit in reversed(pr_data["commits"]["nodes"]) ) yield PullRequest( branch_name=pr_data["headRefName"], - repo_urls=frozenset(url for url in repo_urls if url is not None), + repo_urls=frozenset(repo_urls(pr_data)), hashes=hashes, merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"), ) From 5bd9fbcf9fb3282d841eb6525720828bd73c3320 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Jul 2025 12:59:26 +0100 Subject: [PATCH 4/4] Point release --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2c46e4c..ef17650 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "git-sync" -version = "0.4.1" +version = "0.4.2" description = "Synchronize local git repo with remotes" authors = [{ name = "Alice Purcell", email = "alicederyn@gmail.com" }] requires-python = ">= 3.12"