diff --git a/docs/async.md b/docs/async.md index 089d783191..6e3dad9e72 100644 --- a/docs/async.md +++ b/docs/async.md @@ -189,6 +189,26 @@ async def main(): anyio.run(main, backend='trio') ``` +## Handling Server Disconnects + +In rare cases where the `keep_alive` value of the destination is shorter than the client a `RemoteProtocolException` may be thrown. With `handle_disconnects` set to True a reconnection will be attempeted. If the reconnect is successful and `reduce_disconnects` is set to True it will attempt to reduce future disconencts by reducing the `keep_alive` value of the client. The factor at whcih the keep_alive is reduced can be set by setting reduce_timeout_factor + +```python +import httpx +import trio + +async def main(): + async with httpx.AsyncClient( + handle_disconnects=True, + reduce_disconnects=True, + reduce_timeout_factor=2 + ) as client: + response = await client.get('https://www.example.com/') + print(response) + +trio.run(main) +``` + ## Calling into Python Web Apps For details on calling directly into ASGI applications, see [the `ASGITransport` docs](../advanced/transports#asgitransport). \ No newline at end of file diff --git a/httpx/_client.py b/httpx/_client.py index 2249231f8c..038b1a5f20 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -1366,6 +1366,9 @@ def __init__( timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, follow_redirects: bool = False, limits: Limits = DEFAULT_LIMITS, + handle_disconnects: bool = True, + reduce_disconnects: bool = True, + reduce_timeout_factor: int = 2, max_redirects: int = DEFAULT_MAX_REDIRECTS, event_hooks: None | (typing.Mapping[str, list[EventHook]]) = None, base_url: URL | str = "", @@ -1407,6 +1410,9 @@ def __init__( http2=http2, limits=limits, transport=transport, + handle_disconnects=handle_disconnects, + reduce_disconnects=reduce_disconnects, + reduce_timeout_factor=reduce_timeout_factor, ) self._mounts: dict[URLPattern, AsyncBaseTransport | None] = { @@ -1438,6 +1444,9 @@ def _init_transport( http2: bool = False, limits: Limits = DEFAULT_LIMITS, transport: AsyncBaseTransport | None = None, + handle_disconnects: bool = True, + reduce_disconnects: bool = True, + reduce_timeout_factor: int = 2, ) -> AsyncBaseTransport: if transport is not None: return transport @@ -1449,6 +1458,9 @@ def _init_transport( http1=http1, http2=http2, limits=limits, + handle_disconnects=handle_disconnects, + reduce_disconnects=reduce_disconnects, + reduce_timeout_factor=reduce_timeout_factor, ) def _init_proxy_transport( diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index d5aa05ff23..009481db44 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -290,9 +290,16 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + handle_disconnects: bool = True, + reduce_disconnects: bool = True, + reduce_timeout_factor: int = 2, ) -> None: import httpcore + self.handle_disconnects = handle_disconnects + self.reduce_disconnects = reduce_disconnects + self.reduce_timeout_factor = reduce_timeout_factor + proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) @@ -390,8 +397,17 @@ async def handle_async_request( content=request.stream, extensions=request.extensions, ) - with map_httpcore_exceptions(): - resp = await self._pool.handle_async_request(req) + + try: + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + except RemoteProtocolError: + if not self.handle_disconnects: + raise + await self.areconnect() + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + print("Reconnection Attempt Successful") assert isinstance(resp.stream, typing.AsyncIterable) @@ -402,5 +418,17 @@ async def handle_async_request( extensions=resp.extensions, ) + async def areconnect(self) -> None: + await self._pool.aclose() + + if not self.reduce_disconnects or self._pool._keepalive_expiry is None: + return + print( + "Attempt to reduce future disconnects \ +by reducing timeout by a facotr of %d" + % self.reduce_timeout_factor + ) + self._pool._keepalive_expiry //= self.reduce_timeout_factor + async def aclose(self) -> None: await self._pool.aclose() diff --git a/test b/test new file mode 100644 index 0000000000..a7c01bc6a4 --- /dev/null +++ b/test @@ -0,0 +1 @@ +# TLS secrets log file, generated by OpenSSL / Python diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 8d7eaa3c58..3317893f37 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,3 +1,4 @@ +""" from __future__ import annotations import typing @@ -6,6 +7,130 @@ import pytest import httpx +""" + +import typing +from datetime import timedelta +from unittest.mock import AsyncMock + +import pytest + +import httpx +from httpx._config import Limits +from httpx._transports.default import AsyncHTTPTransport + + +@pytest.mark.anyio +async def test_areconnect_reduce_disconnects_false(server): + """Test areconnect when reduce_disconnects is False.""" + transport = AsyncHTTPTransport( + http2=True, + reduce_disconnects=False, + limits=httpx.Limits( + keepalive_expiry=100, + ), + ) + transport._pool = AsyncMock() # Mock the pool + transport._pool.aclose = AsyncMock() + transport._pool._keepalive_expiry = 60.0 + + await transport.areconnect() + assert transport._pool._keepalive_expiry == 60.0 # Should remain unchanged + transport._pool.aclose.assert_called_once() + + +@pytest.mark.anyio +async def test_areconnect_keepalive_expiry_none(server): + """Test areconnect when keepalive_expiry is None.""" + limits = Limits(keepalive_expiry=None) + transport = AsyncHTTPTransport(http2=True, limits=limits) + transport._pool = AsyncMock() # Mock the pool + transport._pool.aclose = AsyncMock() + transport._pool._keepalive_expiry = None + + await transport.areconnect() + assert transport._pool._keepalive_expiry is None # Should remain None + transport._pool.aclose.assert_called_once() + + +@pytest.mark.anyio +async def test_aexit_exception_mapping(): + """Test that httpcore exceptions during __aexit__ are mapped.""" + import httpcore + + transport = AsyncHTTPTransport() + transport._pool = AsyncMock() + # Configure the mock to raise a specific httpcore exception. + transport._pool.__aexit__ = AsyncMock( + side_effect=httpcore.ConnectError("Mocked ConnectError") + ) + + with pytest.raises(httpx.ConnectError) as exc_info: + async with transport: + pass # The exception will occur during the 'async with' exit. + + assert "Mocked ConnectError" in str(exc_info.value) + + +@pytest.mark.anyio +async def test_remote_protocol_error_reconnect_handling_disabled(server): + """ + If we set the handle_disconnects parameter to false, it will not + attempt to recover from httpcore.RemoteProtocolError exceptions + """ + import httpcore + + transport = AsyncHTTPTransport(handle_disconnects=False) + transport._pool = AsyncMock() + transport._pool.handle_async_request = AsyncMock( + side_effect=httpcore.RemoteProtocolError("Mocked protocol error") + ) + + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(httpx.RemoteProtocolError): + await client.get(server.url) + + +@pytest.mark.anyio +async def test_remote_protocol_error_successfull_reconnect(server): + """ + If httpcore.RemoteProtocolError is rised but reconnections are + set it will try to reconnect once and return normally if it's successful + """ + import httpcore + + transport = AsyncHTTPTransport() + transport._pool = AsyncMock() + transport._pool.handle_async_request = AsyncMock( + side_effect=[ + httpcore.RemoteProtocolError("Mocked protocol error"), + httpcore.Response(200), + ] + ) + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get(server.url) + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_remote_protocol_error_failure_reconnect(server): + """ + If httpcore.RemoteProtocolError is rised but reconnections are + set it will try to reconnect once and return raised exception on second failure + """ + import httpcore + + transport = AsyncHTTPTransport() + transport._pool = AsyncMock() + transport._pool.handle_async_request = AsyncMock( + side_effect=[ + httpcore.RemoteProtocolError("Mocked protocol error"), + httpcore.RemoteProtocolError("Mocked protocol error"), + ] + ) + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(httpx.RemoteProtocolError): + await client.get(server.url) @pytest.mark.anyio @@ -183,7 +308,7 @@ async def test_100_continue(server): async def test_context_managed_transport(): class Transport(httpx.AsyncBaseTransport): def __init__(self) -> None: - self.events: list[str] = [] + self.events: typing.List[str] = [] async def aclose(self): # The base implementation of httpx.AsyncBaseTransport just @@ -216,7 +341,7 @@ async def test_context_managed_transport_and_mount(): class Transport(httpx.AsyncBaseTransport): def __init__(self, name: str) -> None: self.name: str = name - self.events: list[str] = [] + self.events: typing.List[str] = [] async def aclose(self): # The base implementation of httpx.AsyncBaseTransport just