2 changes: 2 additions & 0 deletions CHANGES/11898.breaking.rst
@@ -0,0 +1,2 @@
``Brotli`` and ``brotlicffi`` minimum version is now 1.2.
Decompression now has a default maximum output size of 32MiB per decompress call -- by :user:`Dreamsorcerer`.
121 changes: 75 additions & 46 deletions aiohttp/compression_utils.py
@@ -1,6 +1,7 @@
import asyncio
import sys
import zlib
from abc import ABC, abstractmethod
from concurrent.futures import Executor
from typing import Any, Final, Protocol, TypedDict, cast

@@ -32,7 +33,12 @@
HAS_ZSTD = False


MAX_SYNC_CHUNK_SIZE = 1024
MAX_SYNC_CHUNK_SIZE = 4096
DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB

# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited


class ZLibCompressObjProtocol(Protocol):
@@ -144,19 +150,37 @@ def encoding_to_mode(
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS


class ZlibBaseHandler:
class DecompressionBaseHandler(ABC):
def __init__(
self,
mode: int,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
self._mode = mode
"""Base class for decompression handlers."""
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size

@abstractmethod
def decompress_sync(
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""

async def decompress(
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self.decompress_sync, data, max_length
)
return self.decompress_sync(data, max_length)
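A quick usage sketch of the hand-off above (illustrative only; `aiohttp.compression_utils` is a private module and the data here is made up): compressed payloads larger than `MAX_SYNC_CHUNK_SIZE` (now 4096 bytes) are decompressed in an executor thread so the event loop stays responsive, while smaller payloads skip the thread hand-off.

```python
import asyncio
import os
import zlib

from aiohttp.compression_utils import ZLibDecompressor

async def main() -> None:
    # 64 KiB of random bytes stays incompressible, so the gzip payload is
    # well above the 4096-byte threshold and decompress() takes the
    # executor path rather than blocking the event loop.
    raw = os.urandom(64 * 1024)
    c = zlib.compressobj(wbits=16 + zlib.MAX_WBITS)  # gzip container
    body = c.compress(raw) + c.flush()

    d = ZLibDecompressor(encoding="gzip")
    out = await d.decompress(body, max_length=2**25)  # cap output at 32 MiB
    assert out == raw

asyncio.run(main())
```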


class ZLibCompressor(ZlibBaseHandler):
class ZLibCompressor:
def __init__(
self,
encoding: str | None = None,
@@ -167,14 +191,12 @@ def __init__(
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=(
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
self._mode = (
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

@@ -233,41 +255,24 @@ def flush(self, mode: int | None = None) -> bytes:
)


class ZLibDecompressor(ZlibBaseHandler):
class ZLibDecompressor(DecompressionBaseHandler):
def __init__(
self,
encoding: str | None = None,
suppress_deflate_header: bool = False,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
self._mode = encoding_to_mode(encoding, suppress_deflate_header)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes:
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
return self._decompressor.decompress(data, max_length)

async def decompress(self, data: Buffer, max_length: int = 0) -> bytes:
"""Decompress the data and return the decompressed bytes.

If the data size is larger than the max_sync_chunk_size, the decompression
will be done in the executor. Otherwise, the decompression will be done
in the event loop.
"""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._decompressor.decompress, data, max_length
)
return self.decompress_sync(data, max_length)

def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
@@ -280,40 +285,64 @@ def eof(self) -> bool:
return self._decompressor.eof


class BrotliDecompressor:
class BrotliDecompressor(DecompressionBaseHandler):
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
def __init__(
self,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
) -> None:
"""Decompress data using the Brotli library."""
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

def decompress_sync(self, data: Buffer) -> bytes:
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
return cast(bytes, self._obj.decompress(data, max_length))
return cast(bytes, self._obj.process(data, max_length))

def flush(self) -> bytes:
"""Flush the decompressor."""
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""
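A hedged sketch of the cap pass-through (assumes the Brotli >= 1.2 bindings this PR requires, where the underlying `decompress`/`process` calls accept an output limit using the zlib-style convention of 0 for "unlimited"):

```python
import brotli

from aiohttp.compression_utils import BrotliDecompressor  # private module

LIMIT = 2**25  # matches DEFAULT_MAX_DECOMPRESS_SIZE
payload = brotli.compress(b"\0" * 1024)

d = BrotliDecompressor()
# Ask for at most LIMIT + 1 bytes; 1 KiB of output is far under the cap.
out = d.decompress_sync(payload, max_length=LIMIT + 1)
assert len(out) == 1024
```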


class ZSTDDecompressor:
def __init__(self) -> None:
class ZSTDDecompressor(DecompressionBaseHandler):
def __init__(
self,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
) -> None:
if not HAS_ZSTD:
raise RuntimeError(
"The zstd decompression is not available. "
"Please install `backports.zstd` module"
)
self._obj = ZstdDecompressor()

def decompress_sync(self, data: bytes) -> bytes:
return self._obj.decompress(data)
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

def decompress_sync(
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
# zstd uses -1 for unlimited, while zlib uses 0 for unlimited
# Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
zstd_max_length = (
ZSTD_MAX_LENGTH_UNLIMITED
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
else max_length
)
return self._obj.decompress(data, zstd_max_length)

def flush(self) -> bytes:
return b""
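The sentinel translation in `decompress_sync` above boils down to a one-line mapping; a standalone sketch (the helper name is illustrative, not part of aiohttp):

```python
def to_zstd_max_length(max_length: int) -> int:
    """Map zlib's unlimited sentinel (0) to zstd's (-1)."""
    return -1 if max_length == 0 else max_length

assert to_zstd_max_length(0) == -1         # unlimited: 0 becomes zstd's -1
assert to_zstd_max_length(2**25) == 2**25  # finite caps pass through unchanged
```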
4 changes: 4 additions & 0 deletions aiohttp/http_exceptions.py
@@ -73,6 +73,10 @@ class ContentLengthError(PayloadEncodingError):
"""Not enough data to satisfy content length header."""


class DecompressSizeError(PayloadEncodingError):
"""Decompressed size exceeds the configured limit."""

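Because the new exception subclasses `PayloadEncodingError`, existing handlers for malformed payloads also catch the size-limit failure without changes; a quick check of the hierarchy:

```python
from aiohttp.http_exceptions import (
    BadHttpMessage,
    DecompressSizeError,
    PayloadEncodingError,
)

# DecompressSizeError slots into the existing payload-error hierarchy, so
# broad `except PayloadEncodingError` (or BadHttpMessage) clauses still work.
assert issubclass(DecompressSizeError, PayloadEncodingError)
assert issubclass(DecompressSizeError, BadHttpMessage)
```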

class LineTooLong(BadHttpMessage):
def __init__(
self, line: str, limit: str = "Unknown", actual_size: str = "Unknown"
23 changes: 21 additions & 2 deletions aiohttp/http_parser.py
@@ -13,6 +13,7 @@
from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import (
DEFAULT_MAX_DECOMPRESS_SIZE,
HAS_BROTLI,
HAS_ZSTD,
BrotliDecompressor,
@@ -34,6 +35,7 @@
BadStatusLine,
ContentEncodingError,
ContentLengthError,
DecompressSizeError,
InvalidHeader,
InvalidURLError,
LineTooLong,
@@ -921,7 +923,12 @@ def feed_data(
class DeflateBuffer:
"""DeflateStream decompress stream and feed data into specified stream."""

def __init__(self, out: StreamReader, encoding: str | None) -> None:
def __init__(
self,
out: StreamReader,
encoding: str | None,
max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
) -> None:
self.out = out
self.size = 0
out.total_compressed_bytes = self.size
@@ -946,6 +953,8 @@ def __init__(self, out: StreamReader, encoding: str | None) -> None:
else:
self.decompressor = ZLibDecompressor(encoding=encoding)

self._max_decompress_size = max_decompress_size

def set_exception(
self,
exc: type[BaseException] | BaseException,
@@ -975,14 +984,24 @@ def feed_data(self, chunk: bytes) -> None:
)

try:
chunk = self.decompressor.decompress_sync(chunk)
# Decompress with limit + 1 so we can detect if output exceeds limit
chunk = self.decompressor.decompress_sync(
chunk, max_length=self._max_decompress_size + 1
)
except Exception:
raise ContentEncodingError(
"Can not decode content-encoding: %s" % self.encoding
)

self._started_decoding = True

# Check if decompression limit was exceeded
if len(chunk) > self._max_decompress_size:
raise DecompressSizeError(
"Decompressed data exceeds the configured limit of %d bytes"
% self._max_decompress_size
)

if chunk:
self.out.feed_data(chunk)
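The `limit + 1` trick above deserves a note: requesting exactly the limit could not distinguish output that lands precisely on the cap from output that would overflow it, whereas one extra byte makes the overflow observable without materialising the whole payload. A minimal sketch with plain zlib:

```python
import zlib

LIMIT = 2**25  # DEFAULT_MAX_DECOMPRESS_SIZE

# A body that inflates to exactly the limit is accepted...
ok = zlib.decompressobj().decompress(zlib.compress(b"\0" * LIMIT), LIMIT + 1)
assert len(ok) == LIMIT  # == limit: within bounds

# ...while one that inflates past it returns LIMIT + 1 bytes, which is the
# condition DeflateBuffer turns into DecompressSizeError.
bad = zlib.decompressobj().decompress(
    zlib.compress(b"\0" * (LIMIT + 1000)), LIMIT + 1
)
assert len(bad) == LIMIT + 1  # > limit: rejected
```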

31 changes: 20 additions & 11 deletions aiohttp/multipart.py
@@ -13,7 +13,12 @@

from multidict import CIMultiDict, CIMultiDictProxy

from .compression_utils import ZLibCompressor, ZLibDecompressor
from .abc import AbstractStreamWriter
from .compression_utils import (
DEFAULT_MAX_DECOMPRESS_SIZE,
ZLibCompressor,
ZLibDecompressor,
)
from .hdrs import (
CONTENT_DISPOSITION,
CONTENT_ENCODING,
@@ -261,6 +266,7 @@ def __init__(
*,
subtype: str = "mixed",
default_charset: str | None = None,
max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
) -> None:
self.headers = headers
self._boundary = boundary
@@ -277,6 +283,7 @@ def __init__(
self._prev_chunk: bytes | None = None
self._content_eof = 0
self._cache: dict[str, Any] = {}
self._max_decompress_size = max_decompress_size

def __aiter__(self) -> Self:
return self
@@ -307,7 +314,7 @@ async def read(self, *, decode: bool = False) -> bytes:
data.extend(await self.read_chunk(self.chunk_size))
# https://github.com/python/mypy/issues/17537
if decode: # type: ignore[unreachable]
return self.decode(data)
return await self.decode(data)
return data

async def read_chunk(self, size: int = chunk_size) -> bytes:
@@ -485,7 +492,7 @@ def at_eof(self) -> bool:
"""Returns True if the boundary was reached or False otherwise."""
return self._at_eof

def decode(self, data: bytes) -> bytes:
async def decode(self, data: bytes) -> bytes:
"""Decodes data.

Decoding is done according the specified Content-Encoding
@@ -495,18 +502,18 @@ def decode(self, data: bytes) -> bytes:
data = self._decode_content_transfer(data)
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
if not self._is_form_data and CONTENT_ENCODING in self.headers:
return self._decode_content(data)
return await self._decode_content(data)
return data

def _decode_content(self, data: bytes) -> bytes:
async def _decode_content(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_ENCODING, "").lower()
if encoding == "identity":
return data
if encoding in {"deflate", "gzip"}:
return ZLibDecompressor(
return await ZLibDecompressor(
encoding=encoding,
suppress_deflate_header=True,
).decompress_sync(data)
).decompress(data, max_length=self._max_decompress_size)

raise RuntimeError(f"unknown content encoding: {encoding}")
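Since `decode()` is now a coroutine, multipart consumers that decoded by hand must await it; a hedged caller-side sketch (the function name is illustrative):

```python
from aiohttp import BodyPartReader

async def read_decoded(part: BodyPartReader) -> bytes:
    raw = await part.read(decode=False)
    return await part.decode(raw)  # previously a plain synchronous call
```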

@@ -577,11 +584,11 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
"""
raise TypeError("Unable to read body part as bytes. Use write() to consume.")

async def write(self, writer: Any) -> None:
async def write(self, writer: AbstractStreamWriter) -> None:
field = self._value
chunk = await field.read_chunk(size=2**16)
while chunk:
await writer.write(field.decode(chunk))
await writer.write(await field.decode(chunk))
chunk = await field.read_chunk(size=2**16)


@@ -1024,7 +1031,9 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:

return b"".join(parts)

async def write(self, writer: Any, close_boundary: bool = True) -> None:
async def write(
self, writer: AbstractStreamWriter, close_boundary: bool = True
) -> None:
"""Write body."""
for part, encoding, te_encoding in self._parts:
if self._is_form_data:
@@ -1078,7 +1087,7 @@ async def close(self) -> None:


class MultipartPayloadWriter:
def __init__(self, writer: Any) -> None:
def __init__(self, writer: AbstractStreamWriter) -> None:
self._writer = writer
self._encoding: str | None = None
self._compress: ZLibCompressor | None = None