From 1e4120e87daec963c67f956111e6bca44d7c3dea Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 3 Jan 2026 14:34:56 +0000 Subject: [PATCH 1/2] Limit number of chunks before pausing reading (#11894) Co-authored-by: J. Nick Koston --- aiohttp/streams.py | 25 ++++++- tests/test_streams.py | 170 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 1 deletion(-) diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 8c902d6b1ce..034fcc540c0 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -107,6 +107,8 @@ class StreamReader(AsyncStreamReaderMixin): "_protocol", "_low_water", "_high_water", + "_low_water_chunks", + "_high_water_chunks", "_loop", "_size", "_cursor", @@ -135,6 +137,11 @@ def __init__( self._protocol = protocol self._low_water = limit self._high_water = limit * 2 + # Ensure high_water_chunks >= 3 so it's always > low_water_chunks. + self._high_water_chunks = max(3, limit // 4) + # Use max(2, ...) because there's always at least 1 chunk split remaining + # (the current position), so we need low_water >= 2 to allow resume. + self._low_water_chunks = max(2, self._high_water_chunks // 2) self._loop = loop self._size = 0 self._cursor = 0 @@ -317,6 +324,15 @@ def end_http_chunk_receiving(self) -> None: self._http_chunk_splits.append(self.total_bytes) + # If we get too many small chunks before self._high_water is reached, then any + # .read() call becomes computationally expensive, and could block the event loop + # for too long, hence an additional self._high_water_chunks here. + if ( + len(self._http_chunk_splits) > self._high_water_chunks + and not self._protocol._reading_paused + ): + self._protocol.pause_reading() + # wake up readchunk when end of http chunk received waiter = self._waiter if waiter is not None: @@ -511,7 +527,14 @@ def _read_nowait_chunk(self, n: int) -> bytes: while chunk_splits and chunk_splits[0] < self._cursor: chunk_splits.popleft() - if self._size < self._low_water and self._protocol._reading_paused: + if ( + self._protocol._reading_paused + and self._size < self._low_water + and ( + self._http_chunk_splits is None + or len(self._http_chunk_splits) < self._low_water_chunks + ) + ): self._protocol.resume_reading() return data diff --git a/tests/test_streams.py b/tests/test_streams.py index b59eb77db96..e2fd1659191 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1537,3 +1537,173 @@ async def test_stream_reader_iter_chunks_chunked_encoding( def test_isinstance_check() -> None: assert isinstance(streams.EMPTY_PAYLOAD, streams.StreamReader) + + +async def test_stream_reader_pause_on_high_water_chunks( + protocol: mock.Mock, +) -> None: + """Test that reading is paused when chunk count exceeds high water mark.""" + loop = asyncio.get_event_loop() + # Use small limit so high_water_chunks is small: limit // 4 = 10 + stream = streams.StreamReader(protocol, limit=40, loop=loop) + + assert stream._high_water_chunks == 10 + assert stream._low_water_chunks == 5 + + # Feed chunks until we exceed high_water_chunks + for i in range(12): + stream.begin_http_chunk_receiving() + stream.feed_data(b"x") # 1 byte per chunk + stream.end_http_chunk_receiving() + + # pause_reading should have been called when chunk count exceeded 10 + protocol.pause_reading.assert_called() + + +async def test_stream_reader_resume_on_low_water_chunks( + protocol: mock.Mock, +) -> None: + """Test that reading resumes when chunk count drops below low water mark.""" + loop = asyncio.get_event_loop() + # Use small limit so high_water_chunks is 
small: limit // 4 = 10 + stream = streams.StreamReader(protocol, limit=40, loop=loop) + + assert stream._high_water_chunks == 10 + assert stream._low_water_chunks == 5 + + # Feed chunks until we exceed high_water_chunks + for i in range(12): + stream.begin_http_chunk_receiving() + stream.feed_data(b"x") # 1 byte per chunk + stream.end_http_chunk_receiving() + + # Simulate that reading was paused + protocol._reading_paused = True + protocol.pause_reading.reset_mock() + + # Read data to reduce both size and chunk count + # Reading will consume chunks and reduce _http_chunk_splits + data = await stream.read(10) + assert data == b"xxxxxxxxxx" + + # resume_reading should have been called when both size and chunk count + # dropped below their respective low water marks + protocol.resume_reading.assert_called() + + +async def test_stream_reader_no_resume_when_chunks_still_high( + protocol: mock.Mock, +) -> None: + """Test that reading doesn't resume if chunk count is still above low water.""" + loop = asyncio.get_event_loop() + # Use small limit so high_water_chunks is small: limit // 4 = 10 + stream = streams.StreamReader(protocol, limit=40, loop=loop) + + # Feed many chunks + for i in range(12): + stream.begin_http_chunk_receiving() + stream.feed_data(b"x") + stream.end_http_chunk_receiving() + + # Simulate that reading was paused + protocol._reading_paused = True + + # Read only a few bytes - chunk count will still be high + data = await stream.read(2) + assert data == b"xx" + + # resume_reading should NOT be called because chunk count is still >= low_water_chunks + protocol.resume_reading.assert_not_called() + + +async def test_stream_reader_read_non_chunked_response( + protocol: mock.Mock, +) -> None: + """Test that non-chunked responses work correctly (no chunk tracking).""" + loop = asyncio.get_event_loop() + stream = streams.StreamReader(protocol, limit=40, loop=loop) + + # Non-chunked: just feed data without begin/end_http_chunk_receiving + stream.feed_data(b"Hello World") + + # _http_chunk_splits should be None for non-chunked responses + assert stream._http_chunk_splits is None + + # Reading should work without issues + data = await stream.read(5) + assert data == b"Hello" + + data = await stream.read(6) + assert data == b" World" + + +async def test_stream_reader_resume_non_chunked_when_paused( + protocol: mock.Mock, +) -> None: + """Test that resume works for non-chunked responses when paused due to size.""" + loop = asyncio.get_event_loop() + # Small limit so we can trigger pause via size + stream = streams.StreamReader(protocol, limit=10, loop=loop) + + # Feed data that exceeds high_water (limit * 2 = 20) + stream.feed_data(b"x" * 25) + + # Simulate that reading was paused due to size + protocol._reading_paused = True + protocol.pause_reading.assert_called() + + # Read enough to drop below low_water (limit = 10) + data = await stream.read(20) + assert data == b"x" * 20 + + # resume_reading should be called (size is now 5 < low_water 10) + protocol.resume_reading.assert_called() + + +@pytest.mark.parametrize("limit", [1, 2, 4]) +async def test_stream_reader_small_limit_resumes_reading( + protocol: mock.Mock, + limit: int, +) -> None: + """Test that small limits still allow resume_reading to be called. + + Even with very small limits, high_water_chunks should be at least 3 + and low_water_chunks should be at least 2, with high > low to ensure + proper flow control. 
+ """ + loop = asyncio.get_event_loop() + stream = streams.StreamReader(protocol, limit=limit, loop=loop) + + # Verify minimum thresholds are enforced and high > low + assert stream._high_water_chunks >= 3 + assert stream._low_water_chunks >= 2 + assert stream._high_water_chunks > stream._low_water_chunks + + # Set up pause/resume side effects + def pause_reading() -> None: + protocol._reading_paused = True + + protocol.pause_reading.side_effect = pause_reading + + def resume_reading() -> None: + protocol._reading_paused = False + + protocol.resume_reading.side_effect = resume_reading + + # Feed 4 chunks (triggers pause at > high_water_chunks which is >= 3) + for char in b"abcd": + stream.begin_http_chunk_receiving() + stream.feed_data(bytes([char])) + stream.end_http_chunk_receiving() + + # Reading should now be paused + assert protocol._reading_paused is True + assert protocol.pause_reading.called + + # Read all data - should resume (chunk count drops below low_water_chunks) + data = stream.read_nowait() + assert data == b"abcd" + assert stream._size == 0 + + protocol.resume_reading.assert_called() + assert protocol._reading_paused is False From 92477c5a74c43dfe0474bd24f8de11875daa2298 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 3 Jan 2026 14:40:59 +0000 Subject: [PATCH 2/2] Use decompressor max_length parameter (#11898) Co-authored-by: J. Nick Koston --- CHANGES/11898.breaking.rst | 2 + aiohttp/compression_utils.py | 121 ++++++++++++++++++++------------ aiohttp/http_exceptions.py | 4 ++ aiohttp/http_parser.py | 23 +++++- aiohttp/multipart.py | 31 +++++--- aiohttp/web_request.py | 2 +- docs/spelling_wordlist.txt | 1 + pyproject.toml | 4 +- requirements/runtime-deps.in | 4 +- tests/test_client_functional.py | 110 +++++++++++++++++++++++++++++ tests/test_http_parser.py | 34 +++++++++ tests/test_multipart.py | 47 +++++++------ 12 files changed, 297 insertions(+), 86 deletions(-) create mode 100644 CHANGES/11898.breaking.rst diff --git a/CHANGES/11898.breaking.rst b/CHANGES/11898.breaking.rst new file mode 100644 index 00000000000..cfbf2ae4727 --- /dev/null +++ b/CHANGES/11898.breaking.rst @@ -0,0 +1,2 @@ +``Brotli`` and ``brotlicffi`` minimum version is now 1.2. +Decompression now has a default maximum output size of 32MiB per decompress call -- by :user:`Dreamsorcerer`. 
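The 32 MiB ceiling applies per decompress call because the underlying decompressors accept a max_length argument: a zlib decompress object stops emitting output at max_length bytes and keeps the unread input in unconsumed_tail, so an over-limit expansion can be detected by asking for limit + 1 bytes and checking whether that extra byte arrived, which is how the DeflateBuffer change below uses it. A minimal standalone sketch of that detection using only the standard zlib module (the helper name bounded_decompress and the plain ValueError are illustrative, not aiohttp API):

import zlib

DEFAULT_MAX_DECOMPRESS_SIZE = 2**25  # 32 MiB, matching the new default in aiohttp.compression_utils

def bounded_decompress(compressed: bytes, limit: int = DEFAULT_MAX_DECOMPRESS_SIZE) -> bytes:
    """Decompress, refusing to produce more than `limit` bytes of output."""
    d = zlib.decompressobj()
    # Request at most limit + 1 bytes; seeing the extra byte proves the limit was
    # exceeded without ever materialising the full decompressed payload.
    out = d.decompress(compressed, limit + 1)
    if len(out) > limit:
        raise ValueError(f"decompressed data exceeds the configured limit of {limit} bytes")
    return out

# A 64 MiB run of identical bytes compresses to a few tens of KiB, yet expands
# past the 32 MiB cap, so this raises instead of allocating 64 MiB.
payload = zlib.compress(b"A" * (64 * 2**20))
try:
    bounded_decompress(payload)
except ValueError as exc:
    print(exc)

aiohttp itself reports this condition as DecompressSizeError (a PayloadEncodingError subclass) rather than a plain ValueError, and the zstd path has to translate the "unlimited" sentinel, since zlib uses 0 where zstd uses -1.
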
diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index d9c74fa5400..0bc4a30d8ed 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -1,6 +1,7 @@ import asyncio import sys import zlib +from abc import ABC, abstractmethod from concurrent.futures import Executor from typing import Any, Final, Protocol, TypedDict, cast @@ -32,7 +33,12 @@ HAS_ZSTD = False -MAX_SYNC_CHUNK_SIZE = 1024 +MAX_SYNC_CHUNK_SIZE = 4096 +DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB + +# Unlimited decompression constants - different libraries use different conventions +ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited +ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited class ZLibCompressObjProtocol(Protocol): @@ -144,19 +150,37 @@ def encoding_to_mode( return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS -class ZlibBaseHandler: +class DecompressionBaseHandler(ABC): def __init__( self, - mode: int, executor: Executor | None = None, max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): - self._mode = mode + """Base class for decompression handlers.""" self._executor = executor self._max_sync_chunk_size = max_sync_chunk_size + @abstractmethod + def decompress_sync( + self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED + ) -> bytes: + """Decompress the given data.""" + + async def decompress( + self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED + ) -> bytes: + """Decompress the given data.""" + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.decompress_sync, data, max_length + ) + return self.decompress_sync(data, max_length) + -class ZLibCompressor(ZlibBaseHandler): +class ZLibCompressor: def __init__( self, encoding: str | None = None, @@ -167,14 +191,12 @@ def __init__( executor: Executor | None = None, max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): - super().__init__( - mode=( - encoding_to_mode(encoding, suppress_deflate_header) - if wbits is None - else wbits - ), - executor=executor, - max_sync_chunk_size=max_sync_chunk_size, + self._executor = executor + self._max_sync_chunk_size = max_sync_chunk_size + self._mode = ( + encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits ) self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) @@ -233,7 +255,7 @@ def flush(self, mode: int | None = None) -> bytes: ) -class ZLibDecompressor(ZlibBaseHandler): +class ZLibDecompressor(DecompressionBaseHandler): def __init__( self, encoding: str | None = None, @@ -241,33 +263,16 @@ def __init__( executor: Executor | None = None, max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): - super().__init__( - mode=encoding_to_mode(encoding, suppress_deflate_header), - executor=executor, - max_sync_chunk_size=max_sync_chunk_size, - ) + super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) + self._mode = encoding_to_mode(encoding, suppress_deflate_header) self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode) - def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes: + def decompress_sync( + self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED + ) -> bytes: return self._decompressor.decompress(data, max_length) - async def decompress(self, data: Buffer, max_length: int = 0) -> bytes: - 
"""Decompress the data and return the decompressed bytes. - - If the data size is large than the max_sync_chunk_size, the decompression - will be done in the executor. Otherwise, the decompression will be done - in the event loop. - """ - if ( - self._max_sync_chunk_size is not None - and len(data) > self._max_sync_chunk_size - ): - return await asyncio.get_running_loop().run_in_executor( - self._executor, self._decompressor.decompress, data, max_length - ) - return self.decompress_sync(data, max_length) - def flush(self, length: int = 0) -> bytes: return ( self._decompressor.flush(length) @@ -280,40 +285,64 @@ def eof(self) -> bool: return self._decompressor.eof -class BrotliDecompressor: +class BrotliDecompressor(DecompressionBaseHandler): # Supports both 'brotlipy' and 'Brotli' packages # since they share an import name. The top branches # are for 'brotlipy' and bottom branches for 'Brotli' - def __init__(self) -> None: + def __init__( + self, + executor: Executor | None = None, + max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, + ) -> None: + """Decompress data using the Brotli library.""" if not HAS_BROTLI: raise RuntimeError( "The brotli decompression is not available. " "Please install `Brotli` module" ) self._obj = brotli.Decompressor() + super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) - def decompress_sync(self, data: Buffer) -> bytes: + def decompress_sync( + self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED + ) -> bytes: + """Decompress the given data.""" if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data)) - return cast(bytes, self._obj.process(data)) + return cast(bytes, self._obj.decompress(data, max_length)) + return cast(bytes, self._obj.process(data, max_length)) def flush(self) -> bytes: + """Flush the decompressor.""" if hasattr(self._obj, "flush"): return cast(bytes, self._obj.flush()) return b"" -class ZSTDDecompressor: - def __init__(self) -> None: +class ZSTDDecompressor(DecompressionBaseHandler): + def __init__( + self, + executor: Executor | None = None, + max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, + ) -> None: if not HAS_ZSTD: raise RuntimeError( "The zstd decompression is not available. 
" "Please install `backports.zstd` module" ) self._obj = ZstdDecompressor() - - def decompress_sync(self, data: bytes) -> bytes: - return self._obj.decompress(data) + super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) + + def decompress_sync( + self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED + ) -> bytes: + # zstd uses -1 for unlimited, while zlib uses 0 for unlimited + # Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited) + zstd_max_length = ( + ZSTD_MAX_LENGTH_UNLIMITED + if max_length == ZLIB_MAX_LENGTH_UNLIMITED + else max_length + ) + return self._obj.decompress(data, zstd_max_length) def flush(self) -> bytes: return b"" diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index 544e1d03a25..91abed24308 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -73,6 +73,10 @@ class ContentLengthError(PayloadEncodingError): """Not enough data to satisfy content length header.""" +class DecompressSizeError(PayloadEncodingError): + """Decompressed size exceeds the configured limit.""" + + class LineTooLong(BadHttpMessage): def __init__( self, line: str, limit: str = "Unknown", actual_size: str = "Unknown" diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 63bae7faf3c..79c7f06734f 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -13,6 +13,7 @@ from . import hdrs from .base_protocol import BaseProtocol from .compression_utils import ( + DEFAULT_MAX_DECOMPRESS_SIZE, HAS_BROTLI, HAS_ZSTD, BrotliDecompressor, @@ -34,6 +35,7 @@ BadStatusLine, ContentEncodingError, ContentLengthError, + DecompressSizeError, InvalidHeader, InvalidURLError, LineTooLong, @@ -921,7 +923,12 @@ def feed_data( class DeflateBuffer: """DeflateStream decompress stream and feed data into specified stream.""" - def __init__(self, out: StreamReader, encoding: str | None) -> None: + def __init__( + self, + out: StreamReader, + encoding: str | None, + max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + ) -> None: self.out = out self.size = 0 out.total_compressed_bytes = self.size @@ -946,6 +953,8 @@ def __init__(self, out: StreamReader, encoding: str | None) -> None: else: self.decompressor = ZLibDecompressor(encoding=encoding) + self._max_decompress_size = max_decompress_size + def set_exception( self, exc: type[BaseException] | BaseException, @@ -975,7 +984,10 @@ def feed_data(self, chunk: bytes) -> None: ) try: - chunk = self.decompressor.decompress_sync(chunk) + # Decompress with limit + 1 so we can detect if output exceeds limit + chunk = self.decompressor.decompress_sync( + chunk, max_length=self._max_decompress_size + 1 + ) except Exception: raise ContentEncodingError( "Can not decode content-encoding: %s" % self.encoding @@ -983,6 +995,13 @@ def feed_data(self, chunk: bytes) -> None: self._started_decoding = True + # Check if decompression limit was exceeded + if len(chunk) > self._max_decompress_size: + raise DecompressSizeError( + "Decompressed data exceeds the configured limit of %d bytes" + % self._max_decompress_size + ) + if chunk: self.out.feed_data(chunk) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index c8f216956d7..af935232772 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -13,7 +13,12 @@ from multidict import CIMultiDict, CIMultiDictProxy -from .compression_utils import ZLibCompressor, ZLibDecompressor +from .abc import AbstractStreamWriter +from .compression_utils import ( + DEFAULT_MAX_DECOMPRESS_SIZE, + ZLibCompressor, + 
ZLibDecompressor, +) from .hdrs import ( CONTENT_DISPOSITION, CONTENT_ENCODING, @@ -261,6 +266,7 @@ def __init__( *, subtype: str = "mixed", default_charset: str | None = None, + max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, ) -> None: self.headers = headers self._boundary = boundary @@ -277,6 +283,7 @@ def __init__( self._prev_chunk: bytes | None = None self._content_eof = 0 self._cache: dict[str, Any] = {} + self._max_decompress_size = max_decompress_size def __aiter__(self) -> Self: return self @@ -307,7 +314,7 @@ async def read(self, *, decode: bool = False) -> bytes: data.extend(await self.read_chunk(self.chunk_size)) # https://github.com/python/mypy/issues/17537 if decode: # type: ignore[unreachable] - return self.decode(data) + return await self.decode(data) return data async def read_chunk(self, size: int = chunk_size) -> bytes: @@ -485,7 +492,7 @@ def at_eof(self) -> bool: """Returns True if the boundary was reached or False otherwise.""" return self._at_eof - def decode(self, data: bytes) -> bytes: + async def decode(self, data: bytes) -> bytes: """Decodes data. Decoding is done according the specified Content-Encoding @@ -495,18 +502,18 @@ def decode(self, data: bytes) -> bytes: data = self._decode_content_transfer(data) # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 if not self._is_form_data and CONTENT_ENCODING in self.headers: - return self._decode_content(data) + return await self._decode_content(data) return data - def _decode_content(self, data: bytes) -> bytes: + async def _decode_content(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_ENCODING, "").lower() if encoding == "identity": return data if encoding in {"deflate", "gzip"}: - return ZLibDecompressor( + return await ZLibDecompressor( encoding=encoding, suppress_deflate_header=True, - ).decompress_sync(data) + ).decompress(data, max_length=self._max_decompress_size) raise RuntimeError(f"unknown content encoding: {encoding}") @@ -577,11 +584,11 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt """ raise TypeError("Unable to read body part as bytes. 
Use write() to consume.") - async def write(self, writer: Any) -> None: + async def write(self, writer: AbstractStreamWriter) -> None: field = self._value chunk = await field.read_chunk(size=2**16) while chunk: - await writer.write(field.decode(chunk)) + await writer.write(await field.decode(chunk)) chunk = await field.read_chunk(size=2**16) @@ -1024,7 +1031,9 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt return b"".join(parts) - async def write(self, writer: Any, close_boundary: bool = True) -> None: + async def write( + self, writer: AbstractStreamWriter, close_boundary: bool = True + ) -> None: """Write body.""" for part, encoding, te_encoding in self._parts: if self._is_form_data: @@ -1078,7 +1087,7 @@ async def close(self) -> None: class MultipartPayloadWriter: - def __init__(self, writer: Any) -> None: + def __init__(self, writer: AbstractStreamWriter) -> None: self._writer = writer self._encoding: str | None = None self._compress: ZLibCompressor | None = None diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 2300becb3a1..4a2e6f0bf8e 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -716,7 +716,7 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]": ) chunk = await field.read_chunk(size=2**16) while chunk: - chunk = field.decode(chunk) + chunk = await field.decode(chunk) await self._loop.run_in_executor(None, tmp.write, chunk) size += len(chunk) if 0 < max_size < size: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 29f6b0f364e..4e7bab6968c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -193,6 +193,7 @@ lowercased Mako manylinux metadata +MiB microservice middleware middlewares diff --git a/pyproject.toml b/pyproject.toml index 0cfa7a3221b..7ad4f54b0c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,8 @@ dynamic = [ [project.optional-dependencies] speedups = [ "aiodns >= 3.3.0", - "Brotli; platform_python_implementation == 'CPython'", - "brotlicffi; platform_python_implementation != 'CPython'", + "Brotli >= 1.2; platform_python_implementation == 'CPython'", + "brotlicffi >= 1.2; platform_python_implementation != 'CPython'", "backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14'", ] diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 16515e7551a..3e3c4d05313 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -5,8 +5,8 @@ aiohappyeyeballs >= 2.5.0 aiosignal >= 1.4.0 async-timeout >= 4.0, < 6.0 ; python_version < '3.11' backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14' -Brotli; platform_python_implementation == 'CPython' -brotlicffi; platform_python_implementation != 'CPython' +Brotli >= 1.2; platform_python_implementation == 'CPython' +brotlicffi >= 1.2; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 95b40cce9bb..916261a1d12 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -13,11 +13,25 @@ import tarfile import time import zipfile +import zlib from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import suppress from typing import Any, NoReturn from unittest import mock +try: + try: + import brotlicffi as brotli + except ImportError: + import brotli +except ImportError: + brotli = None # pragma: no 
cover + +try: + from backports.zstd import ZstdCompressor +except ImportError: + ZstdCompressor = None # type: ignore[assignment,misc] # pragma: no cover + import pytest import trustme from multidict import MultiDict @@ -38,6 +52,8 @@ TooManyRedirects, ) from aiohttp.client_reqrep import ClientRequest +from aiohttp.compression_utils import DEFAULT_MAX_DECOMPRESS_SIZE +from aiohttp.http_exceptions import DecompressSizeError from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -2364,6 +2380,100 @@ async def handler(request: web.Request) -> web.Response: await resp.read() +async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> None: + """Test that decompression size limit triggers DecompressSizeError. + + When a compressed payload expands beyond the configured limit, + we raise DecompressSizeError. + """ + # Create a highly compressible payload that exceeds the decompression limit. + # 64MiB of repeated bytes compresses to ~32KB but expands beyond the + # 32MiB per-call limit. + original = b"A" * (64 * 2**20) + compressed = zlib.compress(original) + assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + + async def handler(request: web.Request) -> web.Response: + # Send compressed data with Content-Encoding header + resp = web.Response(body=compressed) + resp.headers["Content-Encoding"] = "deflate" + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + + with pytest.raises(aiohttp.ClientPayloadError) as exc_info: + await resp.read() + + assert isinstance(exc_info.value.__cause__, DecompressSizeError) + assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + + +@pytest.mark.skipif(brotli is None, reason="brotli is not installed") +async def test_payload_decompress_size_limit_brotli( + aiohttp_client: AiohttpClient, +) -> None: + """Test that brotli decompression size limit triggers DecompressSizeError.""" + assert brotli is not None + # Create a highly compressible payload that exceeds the decompression limit. + original = b"A" * (64 * 2**20) + compressed = brotli.compress(original) + assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + + async def handler(request: web.Request) -> web.Response: + resp = web.Response(body=compressed) + resp.headers["Content-Encoding"] = "br" + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + + with pytest.raises(aiohttp.ClientPayloadError) as exc_info: + await resp.read() + + assert isinstance(exc_info.value.__cause__, DecompressSizeError) + assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + + +@pytest.mark.skipif(ZstdCompressor is None, reason="backports.zstd is not installed") +async def test_payload_decompress_size_limit_zstd( + aiohttp_client: AiohttpClient, +) -> None: + """Test that zstd decompression size limit triggers DecompressSizeError.""" + assert ZstdCompressor is not None + # Create a highly compressible payload that exceeds the decompression limit. 
+ original = b"A" * (64 * 2**20) + compressor = ZstdCompressor() + compressed = compressor.compress(original) + compressor.flush() + assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE + + async def handler(request: web.Request) -> web.Response: + resp = web.Response(body=compressed) + resp.headers["Content-Encoding"] = "zstd" + return resp + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + async with client.get("/") as resp: + assert resp.status == 200 + + with pytest.raises(aiohttp.ClientPayloadError) as exc_info: + await resp.read() + + assert isinstance(exc_info.value.__cause__, DecompressSizeError) + assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + + async def test_bad_payload_chunked_encoding(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index ef597ccca46..869799583e1 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -3,6 +3,7 @@ import asyncio import re import sys +import zlib from collections.abc import Iterable from contextlib import suppress from typing import Any @@ -1970,3 +1971,36 @@ async def test_empty_body(self, protocol: BaseProtocol) -> None: dbuf.feed_eof() assert buf.at_eof() + + @pytest.mark.parametrize( + "chunk_size", + [1024, 2**14, 2**16], # 1KB, 16KB, 64KB + ids=["1KB", "16KB", "64KB"], + ) + async def test_streaming_decompress_large_payload( + self, protocol: BaseProtocol, chunk_size: int + ) -> None: + """Test that large payloads decompress correctly when streamed in chunks. + + This simulates real HTTP streaming where compressed data arrives in + small network chunks. Each chunk's decompressed output should be within + the max_decompress_size limit, allowing full recovery of the original data. + """ + # Create a large payload (3MiB) that compresses well + original = b"A" * (3 * 2**20) + compressed = zlib.compress(original) + + buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + dbuf = DeflateBuffer(buf, "deflate") + + # Feed compressed data in chunks (simulating network streaming) + for i in range(0, len(compressed), chunk_size): + chunk = compressed[i : i + chunk_size] + dbuf.feed_data(chunk) + + dbuf.feed_eof() + + # Read all decompressed data + result = b"".join(buf._buffer) + assert len(result) == len(original) + assert result == original diff --git a/tests/test_multipart.py b/tests/test_multipart.py index fdc0389c361..395f620aed2 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -11,6 +11,7 @@ import aiohttp from aiohttp import payload +from aiohttp.abc import AbstractStreamWriter from aiohttp.compression_utils import ZLibBackend from aiohttp.hdrs import ( CONTENT_DISPOSITION, @@ -43,14 +44,14 @@ def buf() -> bytearray: @pytest.fixture -def stream(buf: bytearray) -> mock.Mock: - writer = mock.Mock() +def stream(buf: bytearray) -> AbstractStreamWriter: + writer = mock.create_autospec(AbstractStreamWriter, instance=True, spec_set=True) async def write(chunk: bytes) -> None: buf.extend(chunk) writer.write.side_effect = write - return writer + return writer # type: ignore[no-any-return] @pytest.fixture @@ -393,7 +394,7 @@ async def test_decode_with_content_transfer_encoding_base64(self) -> None: result = b"" while not obj.at_eof(): chunk = await obj.read_chunk(size=6) - result += obj.decode(chunk) + result += await obj.decode(chunk) assert b"Time to Relax!" 
== result async def test_read_with_content_transfer_encoding_quoted_printable(self) -> None: @@ -1031,7 +1032,7 @@ async def test_writer(writer: aiohttp.MultipartWriter) -> None: async def test_writer_serialize_io_chunk( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with io.BytesIO(b"foobarbaz") as file_handle: writer.append(file_handle) @@ -1043,7 +1044,7 @@ async def test_writer_serialize_io_chunk( async def test_writer_serialize_json( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append_json({"привет": "мир"}) await writer.write(stream) @@ -1054,7 +1055,7 @@ async def test_writer_serialize_json( async def test_writer_serialize_form( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: data = [("foo", "bar"), ("foo", "baz"), ("boo", "zoo")] writer.append_form(data) @@ -1064,7 +1065,7 @@ async def test_writer_serialize_form( async def test_writer_serialize_form_dict( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: data = {"hello": "мир"} writer.append_form(data) @@ -1074,7 +1075,7 @@ async def test_writer_serialize_form_dict( async def test_writer_write( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("foo-bar-baz") writer.append_json({"test": "passed"}) @@ -1121,7 +1122,9 @@ async def test_writer_write( ) == bytes(buf) -async def test_writer_write_no_close_boundary(buf: bytearray, stream: Stream) -> None: +async def test_writer_write_no_close_boundary( + buf: bytearray, stream: AbstractStreamWriter +) -> None: writer = aiohttp.MultipartWriter(boundary=":") writer.append("foo-bar-baz") writer.append_json({"test": "passed"}) @@ -1154,7 +1157,7 @@ async def test_writer_write_no_close_boundary(buf: bytearray, stream: Stream) -> async def test_writer_write_no_parts( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: await writer.write(stream) assert b"--:--\r\n" == bytes(buf) @@ -1163,7 +1166,7 @@ async def test_writer_write_no_parts( @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_writer_serialize_with_content_encoding_gzip( buf: bytearray, - stream: Stream, + stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter, ) -> None: writer.append("Time to Relax!", {CONTENT_ENCODING: "gzip"}) @@ -1182,7 +1185,7 @@ async def test_writer_serialize_with_content_encoding_gzip( async def test_writer_serialize_with_content_encoding_deflate( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Time to Relax!", {CONTENT_ENCODING: "deflate"}) await writer.write(stream) @@ -1198,7 +1201,7 @@ async def test_writer_serialize_with_content_encoding_deflate( async def test_writer_serialize_with_content_encoding_identity( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: thing = 
b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00" writer.append(thing, {CONTENT_ENCODING: "identity"}) @@ -1215,14 +1218,14 @@ async def test_writer_serialize_with_content_encoding_identity( def test_writer_serialize_with_content_encoding_unknown( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with pytest.raises(RuntimeError): writer.append("Time to Relax!", {CONTENT_ENCODING: "snappy"}) async def test_writer_with_content_transfer_encoding_base64( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "base64"}) await writer.write(stream) @@ -1237,7 +1240,7 @@ async def test_writer_with_content_transfer_encoding_base64( async def test_writer_content_transfer_encoding_quote_printable( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Привет, мир!", {CONTENT_TRANSFER_ENCODING: "quoted-printable"}) await writer.write(stream) @@ -1255,7 +1258,7 @@ async def test_writer_content_transfer_encoding_quote_printable( def test_writer_content_transfer_encoding_unknown( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with pytest.raises(RuntimeError): writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "unknown"}) @@ -1385,7 +1388,7 @@ def test_append_none_not_allowed(self) -> None: writer.append(None) async def test_write_preserves_content_disposition( - self, buf: bytearray, stream: Stream + self, buf: bytearray, stream: AbstractStreamWriter ) -> None: with aiohttp.MultipartWriter(boundary=":") as writer: part = writer.append(b"foo", headers={CONTENT_TYPE: "test/passed"}) @@ -1404,7 +1407,7 @@ async def test_write_preserves_content_disposition( assert message == b"foo\r\n--:--\r\n" async def test_preserve_content_disposition_header( - self, buf: bytearray, stream: Stream + self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: @@ -1430,7 +1433,7 @@ async def test_preserve_content_disposition_header( ) async def test_set_content_disposition_override( - self, buf: bytearray, stream: Stream + self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: @@ -1456,7 +1459,7 @@ async def test_set_content_disposition_override( ) async def test_reset_content_disposition_header( - self, buf: bytearray, stream: Stream + self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: