diff --git a/CHANGES/2596.bugfix.rst b/CHANGES/2596.bugfix.rst new file mode 100644 index 00000000000..e172506bcde --- /dev/null +++ b/CHANGES/2596.bugfix.rst @@ -0,0 +1,2 @@ +Fixed proxy authorization headers not being passed when reusing a connection, which caused 407 (Proxy authentication required) errors +-- by :user:`GLeurquin`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index ec2b86d1495..7f81d3e5dd6 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -151,6 +151,7 @@ Georges Dubus Greg Holt Gregory Haynes Grigoriy Soldatov +Guillaume Leurquin Gus Goulart Gustavo Carneiro Günther Jena diff --git a/aiohttp/connector.py b/aiohttp/connector.py index b1820358bae..2cdc425d83f 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -550,6 +550,30 @@ def _available_connections(self, key: "ConnectionKey") -> int: return total_remain + def _update_proxy_auth_header_and_build_proxy_req( + self, req: ClientRequest + ) -> ClientRequestBase: + """Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests.""" + url = req.proxy + assert url is not None + headers = req.proxy_headers or CIMultiDict[str]() + headers[hdrs.HOST] = req.headers[hdrs.HOST] + proxy_req = ClientRequestBase( + hdrs.METH_GET, + url, + headers=headers, + auth=req.proxy_auth, + loop=self._loop, + ssl=req.ssl, + ) + auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) + if auth is not None: + if not req.is_ssl(): + req.headers[hdrs.PROXY_AUTHORIZATION] = auth + else: + proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth + return proxy_req + async def connect( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> Connection: @@ -558,12 +582,16 @@ async def connect( if (conn := await self._get(key, traces)) is not None: # If we do not have to wait and we can get a connection from the pool # we can avoid the timeout ceil logic and directly return the connection + if req.proxy: + self._update_proxy_auth_header_and_build_proxy_req(req) return conn async with ceil_timeout(timeout.connect, timeout.ceil_threshold): if self._available_connections(key) <= 0: await self._wait_for_available_connection(key, traces) if (conn := await self._get(key, traces)) is not None: + if req.proxy: + self._update_proxy_auth_header_and_build_proxy_req(req) return conn placeholder = cast( @@ -1453,32 +1481,13 @@ async def _create_direct_connection( async def _create_proxy_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> tuple[asyncio.BaseTransport, ResponseHandler]: - headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers - headers[hdrs.HOST] = req.headers[hdrs.HOST] - - url = req.proxy - assert url is not None - proxy_req = ClientRequestBase( - hdrs.METH_GET, - url, - headers=headers, - auth=req.proxy_auth, - loop=self._loop, - ssl=req.ssl, - ) + proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req) # create connection to proxy server transport, proto = await self._create_direct_connection( proxy_req, [], timeout, client_error=ClientProxyConnectionError ) - auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) - if auth is not None: - if not req.is_ssl(): - req.headers[hdrs.PROXY_AUTHORIZATION] = auth - else: - proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth - if req.is_ssl(): self._warn_about_tls_in_tls(transport, req) diff --git a/tests/test_connector.py b/tests/test_connector.py index 1d5ed0c01a0..ae5e2e068b0 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1,5 +1,6 @@ # Tests of http client with custom Connector import asyncio +import contextlib import gc import hashlib import platform @@ -16,6 +17,7 @@ from unittest import mock import pytest +from multidict import CIMultiDict from pytest_mock import MockerFixture from yarl import URL @@ -25,6 +27,7 @@ ClientSession, ClientTimeout, connector as connector_module, + hdrs, web, ) from aiohttp.abc import ResolveResult @@ -3299,6 +3302,93 @@ async def test_connect_reuseconn_tracing( await conn.close() +@pytest.mark.parametrize( + "test_case,wait_for_con,expect_proxy_auth_header", + [ + ("use_proxy_with_embedded_auth", False, True), + ("use_proxy_with_auth_headers", True, True), + ("use_proxy_no_auth", False, False), + ("dont_use_proxy", False, False), + ], +) +async def test_connect_reuse_proxy_headers( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, + test_case: str, + wait_for_con: bool, + expect_proxy_auth_header: bool, +) -> None: + proto = create_mocked_conn(loop) + proto.is_connected.return_value = True + + if test_case != "dont_use_proxy": + proxy = ( + URL("http://user:password@example.com") + if test_case == "use_proxy_with_embedded_auth" + else URL("http://example.com") + ) + proxy_headers = ( + CIMultiDict({hdrs.AUTHORIZATION: "Basic dXNlcjpwYXNzd29yZA=="}) + if test_case == "use_proxy_with_auth_headers" + else None + ) + else: + proxy = None + proxy_headers = None + key = ConnectionKey( + "localhost", + 80, + False, + True, + proxy, + None, + hash(tuple(proxy_headers.items())) if proxy_headers else None, + ) + req = make_client_request( + "GET", + URL("http://localhost:80"), + loop=loop, + response_class=mock.Mock(), + proxy=proxy, + proxy_headers=proxy_headers, + ) + + conn = aiohttp.BaseConnector(limit=1) + + async def _create_con(*args: Any, **kwargs: Any) -> None: + conn._conns[key] = deque([(proto, loop.time())]) + + with contextlib.ExitStack() as stack: + if wait_for_con: + # Simulate no available connections + stack.enter_context( + mock.patch.object( + conn, "_available_connections", autospec=True, return_value=0 + ) + ) + # Upon waiting for a connection, populate _conns with our proto, + # mocking a connection becoming immediately available + stack.enter_context( + mock.patch.object( + conn, + "_wait_for_available_connection", + autospec=True, + side_effect=_create_con, + ) + ) + else: + await _create_con() + # Call function to test + conn2 = await conn.connect(req, [], ClientTimeout()) + conn2.release() + await conn.close() + + if expect_proxy_auth_header: + assert req.headers[hdrs.PROXY_AUTHORIZATION] == "Basic dXNlcjpwYXNzd29yZA==" + else: + assert hdrs.PROXY_AUTHORIZATION not in req.headers + + async def test_connect_with_limit_and_limit_per_host( loop: asyncio.AbstractEventLoop, key: ConnectionKey,