diff --git a/google/_async_resumable_media/requests/download.py b/google/_async_resumable_media/requests/download.py index d4af79d9..4781c875 100644 --- a/google/_async_resumable_media/requests/download.py +++ b/google/_async_resumable_media/requests/download.py @@ -452,14 +452,21 @@ def __init__(self, checksum): super(_GzipDecoder, self).__init__() self._checksum = checksum - def decompress(self, data): + def decompress(self, data, max_length=-1): """Decompress the bytes. Args: data (bytes): The compressed bytes to be decompressed. + max_length (int): Maximum number of bytes to return. -1 for no + limit. Forwarded to the underlying decoder when supported. Returns: bytes: The decompressed bytes from ``data``. """ self._checksum.update(data) - return super(_GzipDecoder, self).decompress(data) + try: + return super(_GzipDecoder, self).decompress( + data, max_length=max_length + ) + except TypeError: + return super(_GzipDecoder, self).decompress(data) diff --git a/google/resumable_media/requests/download.py b/google/resumable_media/requests/download.py index 1472c9f2..be017f54 100644 --- a/google/resumable_media/requests/download.py +++ b/google/resumable_media/requests/download.py @@ -667,17 +667,22 @@ def __init__(self, checksum): super().__init__() self._checksum = checksum - def decompress(self, data): + def decompress(self, data, max_length=-1): """Decompress the bytes. Args: data (bytes): The compressed bytes to be decompressed. + max_length (int): Maximum number of bytes to return. -1 for no + limit. Forwarded to the underlying decoder when supported. Returns: bytes: The decompressed bytes from ``data``. """ self._checksum.update(data) - return super().decompress(data) + try: + return super().decompress(data, max_length=max_length) + except TypeError: + return super().decompress(data) # urllib3.response.BrotliDecoder might not exist depending on whether brotli is @@ -703,17 +708,29 @@ def __init__(self, checksum): self._decoder = urllib3.response.BrotliDecoder() self._checksum = checksum - def decompress(self, data): + def decompress(self, data, max_length=-1): """Decompress the bytes. Args: data (bytes): The compressed bytes to be decompressed. + max_length (int): Maximum number of bytes to return. -1 for no + limit. Forwarded to the underlying decoder when supported. Returns: bytes: The decompressed bytes from ``data``. """ self._checksum.update(data) - return self._decoder.decompress(data) + try: + return self._decoder.decompress(data, max_length=max_length) + except TypeError: + return self._decoder.decompress(data) + + @property + def has_unconsumed_tail(self): + try: + return self._decoder.has_unconsumed_tail + except AttributeError: + return False def flush(self): return self._decoder.flush() diff --git a/tests/unit/requests/test_download.py b/tests/unit/requests/test_download.py index 713543cb..2678c614 100644 --- a/tests/unit/requests/test_download.py +++ b/tests/unit/requests/test_download.py @@ -1274,6 +1274,39 @@ def test_decompress(self): assert result == b"" md5_hash.update.assert_called_once_with(data) + def test_decompress_with_max_length(self): + md5_hash = mock.Mock(spec=["update"]) + decoder = download_mod._GzipDecoder(md5_hash) + + with mock.patch.object( + type(decoder).__bases__[0], "decompress" + ) as mock_super_decompress: + mock_super_decompress.return_value = b"decompressed" + data = b"\x1f\x8b\x08\x08" + result = decoder.decompress(data, max_length=10) + + assert result == b"decompressed" + md5_hash.update.assert_called_once_with(data) + mock_super_decompress.assert_called_once_with( + data, max_length=10 + ) + + def test_decompress_with_max_length_fallback(self): + md5_hash = mock.Mock(spec=["update"]) + decoder = download_mod._GzipDecoder(md5_hash) + + with mock.patch.object( + type(decoder).__bases__[0], + "decompress", + side_effect=[TypeError, b"decompressed"], + ) as mock_super_decompress: + data = b"\x1f\x8b\x08\x08" + result = decoder.decompress(data, max_length=10) + + assert result == b"decompressed" + md5_hash.update.assert_called_once_with(data) + assert mock_super_decompress.call_count == 2 + class Test_BrotliDecoder(object): def test_constructor(self): @@ -1290,6 +1323,46 @@ def test_decompress(self): assert result == b"" md5_hash.update.assert_called_once_with(data) + def test_decompress_with_max_length(self): + md5_hash = mock.Mock(spec=["update"]) + decoder = download_mod._BrotliDecoder(md5_hash) + + decoder._decoder = mock.Mock(spec=["decompress"]) + decoder._decoder.decompress.return_value = b"decompressed" + + data = b"compressed" + result = decoder.decompress(data, max_length=10) + + assert result == b"decompressed" + md5_hash.update.assert_called_once_with(data) + decoder._decoder.decompress.assert_called_once_with( + data, max_length=10 + ) + + def test_decompress_with_max_length_fallback(self): + md5_hash = mock.Mock(spec=["update"]) + decoder = download_mod._BrotliDecoder(md5_hash) + + decoder._decoder = mock.Mock(spec=["decompress"]) + decoder._decoder.decompress.side_effect = [TypeError, b"decompressed"] + + data = b"compressed" + result = decoder.decompress(data, max_length=10) + + assert result == b"decompressed" + md5_hash.update.assert_called_once_with(data) + assert decoder._decoder.decompress.call_count == 2 + + def test_has_unconsumed_tail(self): + decoder = download_mod._BrotliDecoder(mock.sentinel.md5_hash) + decoder._decoder = mock.Mock(spec=["has_unconsumed_tail"]) + decoder._decoder.has_unconsumed_tail = True + assert decoder.has_unconsumed_tail is True + + def test_has_unconsumed_tail_fallback(self): + decoder = download_mod._BrotliDecoder(mock.sentinel.md5_hash) + decoder._decoder = mock.Mock(spec=[]) + assert decoder.has_unconsumed_tail is False def _mock_response(status_code=http.client.OK, chunks=(), headers=None): if headers is None: diff --git a/tests_async/unit/requests/test_download.py b/tests_async/unit/requests/test_download.py index 6e3ef3fd..cdb0b581 100644 --- a/tests_async/unit/requests/test_download.py +++ b/tests_async/unit/requests/test_download.py @@ -761,6 +761,23 @@ def test_decompress(self): assert result == b"" md5_hash.update.assert_called_once_with(data) + def test_decompress_with_max_length(self): + md5_hash = mock.Mock(spec=["update"]) + decoder = download_mod._GzipDecoder(md5_hash) + + with mock.patch.object( + type(decoder).__bases__[0], "decompress" + ) as mock_super_decompress: + mock_super_decompress.return_value = b"decompressed" + data = b"\x1f\x8b\x08\x08" + result = decoder.decompress(data, max_length=10) + + assert result == b"decompressed" + md5_hash.update.assert_called_once_with(data) + mock_super_decompress.assert_called_once_with( + data, max_length=10 + ) + class AsyncIter: def __init__(self, items):