From 2aa67b29f41bd8ab79c0e6ee86bf09c5a1506b8e Mon Sep 17 00:00:00 2001
From: Cal Nightingale
Date: Wed, 28 Jan 2026 16:37:39 -0700
Subject: [PATCH] fix: add max_length support to Gzip/Brotli decoders for
 urllib3 2.6+ compatibility

urllib3 2.6+ passes max_length to decoder.decompress(). Update
_GzipDecoder and _BrotliDecoder to accept and forward max_length, with a
TypeError fallback for older urllib3. Add has_unconsumed_tail property
to _BrotliDecoder (the proxy class that needs it; _GzipDecoder inherits
it from the parent).

Fixes #491
---
 .../requests/download.py                    | 11 ++-
 google/resumable_media/requests/download.py | 25 ++++++-
 tests/unit/requests/test_download.py        | 73 +++++++++++++++++++
 tests_async/unit/requests/test_download.py  | 17 +++++
 4 files changed, 120 insertions(+), 6 deletions(-)

diff --git a/google/_async_resumable_media/requests/download.py b/google/_async_resumable_media/requests/download.py
index d4af79d9..4781c875 100644
--- a/google/_async_resumable_media/requests/download.py
+++ b/google/_async_resumable_media/requests/download.py
@@ -452,14 +452,21 @@ def __init__(self, checksum):
         super(_GzipDecoder, self).__init__()
         self._checksum = checksum
 
-    def decompress(self, data):
+    def decompress(self, data, max_length=-1):
         """Decompress the bytes.
 
         Args:
             data (bytes): The compressed bytes to be decompressed.
+            max_length (int): Maximum number of bytes to return. -1 for no
+                limit. Forwarded to the underlying decoder when supported.
 
         Returns:
             bytes: The decompressed bytes from ``data``.
         """
         self._checksum.update(data)
-        return super(_GzipDecoder, self).decompress(data)
+        try:
+            return super(_GzipDecoder, self).decompress(
+                data, max_length=max_length
+            )
+        except TypeError:
+            return super(_GzipDecoder, self).decompress(data)
diff --git a/google/resumable_media/requests/download.py b/google/resumable_media/requests/download.py
index 1472c9f2..be017f54 100644
--- a/google/resumable_media/requests/download.py
+++ b/google/resumable_media/requests/download.py
@@ -667,17 +667,22 @@ def __init__(self, checksum):
         super().__init__()
         self._checksum = checksum
 
-    def decompress(self, data):
+    def decompress(self, data, max_length=-1):
         """Decompress the bytes.
 
         Args:
             data (bytes): The compressed bytes to be decompressed.
+            max_length (int): Maximum number of bytes to return. -1 for no
+                limit. Forwarded to the underlying decoder when supported.
 
         Returns:
             bytes: The decompressed bytes from ``data``.
         """
         self._checksum.update(data)
-        return super().decompress(data)
+        try:
+            return super().decompress(data, max_length=max_length)
+        except TypeError:
+            return super().decompress(data)
 
 
 # urllib3.response.BrotliDecoder might not exist depending on whether brotli is
@@ -703,17 +708,29 @@ def __init__(self, checksum):
         self._decoder = urllib3.response.BrotliDecoder()
         self._checksum = checksum
 
-    def decompress(self, data):
+    def decompress(self, data, max_length=-1):
         """Decompress the bytes.
 
         Args:
             data (bytes): The compressed bytes to be decompressed.
+            max_length (int): Maximum number of bytes to return. -1 for no
+                limit. Forwarded to the underlying decoder when supported.
 
         Returns:
             bytes: The decompressed bytes from ``data``.
         """
         self._checksum.update(data)
-        return self._decoder.decompress(data)
+        try:
+            return self._decoder.decompress(data, max_length=max_length)
+        except TypeError:
+            return self._decoder.decompress(data)
+
+    @property
+    def has_unconsumed_tail(self):
+        try:
+            return self._decoder.has_unconsumed_tail
+        except AttributeError:
+            return False
 
     def flush(self):
         return self._decoder.flush()
diff --git a/tests/unit/requests/test_download.py b/tests/unit/requests/test_download.py
index 713543cb..2678c614 100644
--- a/tests/unit/requests/test_download.py
+++ b/tests/unit/requests/test_download.py
@@ -1274,6 +1274,39 @@ def test_decompress(self):
         assert result == b""
         md5_hash.update.assert_called_once_with(data)
 
+    def test_decompress_with_max_length(self):
+        md5_hash = mock.Mock(spec=["update"])
+        decoder = download_mod._GzipDecoder(md5_hash)
+
+        with mock.patch.object(
+            type(decoder).__bases__[0], "decompress"
+        ) as mock_super_decompress:
+            mock_super_decompress.return_value = b"decompressed"
+            data = b"\x1f\x8b\x08\x08"
+            result = decoder.decompress(data, max_length=10)
+
+        assert result == b"decompressed"
+        md5_hash.update.assert_called_once_with(data)
+        mock_super_decompress.assert_called_once_with(
+            data, max_length=10
+        )
+
+    def test_decompress_with_max_length_fallback(self):
+        md5_hash = mock.Mock(spec=["update"])
+        decoder = download_mod._GzipDecoder(md5_hash)
+
+        with mock.patch.object(
+            type(decoder).__bases__[0],
+            "decompress",
+            side_effect=[TypeError, b"decompressed"],
+        ) as mock_super_decompress:
+            data = b"\x1f\x8b\x08\x08"
+            result = decoder.decompress(data, max_length=10)
+
+        assert result == b"decompressed"
+        md5_hash.update.assert_called_once_with(data)
+        assert mock_super_decompress.call_count == 2
+
 
 class Test_BrotliDecoder(object):
     def test_constructor(self):
@@ -1290,6 +1323,46 @@ def test_decompress(self):
         assert result == b""
         md5_hash.update.assert_called_once_with(data)
 
+    def test_decompress_with_max_length(self):
+        md5_hash = mock.Mock(spec=["update"])
+        decoder = download_mod._BrotliDecoder(md5_hash)
+
+        decoder._decoder = mock.Mock(spec=["decompress"])
+        decoder._decoder.decompress.return_value = b"decompressed"
+
+        data = b"compressed"
+        result = decoder.decompress(data, max_length=10)
+
+        assert result == b"decompressed"
+        md5_hash.update.assert_called_once_with(data)
+        decoder._decoder.decompress.assert_called_once_with(
+            data, max_length=10
+        )
+
+    def test_decompress_with_max_length_fallback(self):
+        md5_hash = mock.Mock(spec=["update"])
+        decoder = download_mod._BrotliDecoder(md5_hash)
+
+        decoder._decoder = mock.Mock(spec=["decompress"])
+        decoder._decoder.decompress.side_effect = [TypeError, b"decompressed"]
+
+        data = b"compressed"
+        result = decoder.decompress(data, max_length=10)
+
+        assert result == b"decompressed"
+        md5_hash.update.assert_called_once_with(data)
+        assert decoder._decoder.decompress.call_count == 2
+
+    def test_has_unconsumed_tail(self):
+        decoder = download_mod._BrotliDecoder(mock.sentinel.md5_hash)
+        decoder._decoder = mock.Mock(spec=["has_unconsumed_tail"])
+        decoder._decoder.has_unconsumed_tail = True
+        assert decoder.has_unconsumed_tail is True
+
+    def test_has_unconsumed_tail_fallback(self):
+        decoder = download_mod._BrotliDecoder(mock.sentinel.md5_hash)
+        decoder._decoder = mock.Mock(spec=[])
+        assert decoder.has_unconsumed_tail is False
 
 def _mock_response(status_code=http.client.OK, chunks=(), headers=None):
     if headers is None:
diff --git a/tests_async/unit/requests/test_download.py b/tests_async/unit/requests/test_download.py
index 6e3ef3fd..cdb0b581 100644
--- a/tests_async/unit/requests/test_download.py
+++ b/tests_async/unit/requests/test_download.py
@@ -761,6 +761,23 @@ def test_decompress(self):
         assert result == b""
         md5_hash.update.assert_called_once_with(data)
 
+    def test_decompress_with_max_length(self):
+        md5_hash = mock.Mock(spec=["update"])
+        decoder = download_mod._GzipDecoder(md5_hash)
+
+        with mock.patch.object(
+            type(decoder).__bases__[0], "decompress"
+        ) as mock_super_decompress:
+            mock_super_decompress.return_value = b"decompressed"
+            data = b"\x1f\x8b\x08\x08"
+            result = decoder.decompress(data, max_length=10)
+
+        assert result == b"decompressed"
+        md5_hash.update.assert_called_once_with(data)
+        mock_super_decompress.assert_called_once_with(
+            data, max_length=10
+        )
+
 
 class AsyncIter:
     def __init__(self, items):
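
Note: the TypeError fallback used in both decoders is a general keyword-forwarding
pattern. A minimal standalone sketch of the idea, using hypothetical toy backends
rather than the urllib3 decoder classes:

    class OldBackend:
        """Old-style decoder: decompress() takes no max_length."""

        def decompress(self, data):
            return data.upper()


    class NewBackend:
        """New-style decoder: decompress() accepts max_length (-1 = no limit)."""

        def decompress(self, data, max_length=-1):
            out = data.upper()
            return out if max_length < 0 else out[:max_length]


    def forwarding_decompress(backend, data, max_length=-1):
        # Prefer the newer signature; if the backend rejects the keyword with
        # TypeError, retry with the old single-argument call.
        try:
            return backend.decompress(data, max_length=max_length)
        except TypeError:
            return backend.decompress(data)


    assert forwarding_decompress(OldBackend(), b"abc", max_length=2) == b"ABC"
    assert forwarding_decompress(NewBackend(), b"abc", max_length=2) == b"AB"

One caveat of the pattern as written: a TypeError raised from inside a decoder that
does accept max_length would also trigger the retry; the patch accepts that tradeoff
in exchange for not having to inspect the installed urllib3 version.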