Skip to content

Commit 8aba092

Browse files
committed
addressing comments:
1 parent 0b4e6a7 commit 8aba092

File tree

2 files changed

+85
-101
lines changed

2 files changed

+85
-101
lines changed

google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py

Lines changed: 32 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -25,15 +25,14 @@
2525
class _WriteState:
2626
"""A helper class to track the state of a single upload operation.
2727
28-
Attributes:
29-
spec (AppendObjectSpec): The specification for the object to write.
30-
chunk_size (int): The size of chunks to read from the buffer.
31-
user_buffer (IO[bytes]): The data source.
32-
persisted_size (int): The amount of data confirmed as persisted by the server.
33-
bytes_sent (int): The amount of data currently sent in the active stream.
34-
write_handle (bytes | BidiWriteHandle | None): The handle for the append session.
35-
routing_token (str | None): Token for routing to the correct backend.
36-
is_complete (bool): Whether the upload has finished.
28+
:type spec: :class:`google.cloud.storage_v2.types.AppendObjectSpec`
29+
:param spec: The specification for the object to write.
30+
31+
:type chunk_size: int
32+
:param chunk_size: The size of chunks to write to the server.
33+
34+
:type user_buffer: IO[bytes]
35+
:param user_buffer: The data source.
3736
"""
3837

3938
def __init__(
@@ -71,23 +70,29 @@ def generate_requests(
7170
if write_state.routing_token:
7271
write_state.spec.routing_token = write_state.routing_token
7372

73+
# Initial request of the stream must provide the specification.
74+
# If we have a write_handle, we request a state lookup to verify persisted offset.
7475
do_state_lookup = write_state.write_handle is not None
75-
yield storage_type.BidiWriteObjectRequest(
76-
append_object_spec=write_state.spec, state_lookup=do_state_lookup
76+
77+
# Determine if we need to send WriteObjectSpec or AppendObjectSpec
78+
initial_request = storage_type.BidiWriteObjectRequest(
79+
state_lookup=do_state_lookup
7780
)
7881

82+
if isinstance(write_state.spec, storage_type.WriteObjectSpec):
83+
initial_request.write_object_spec = write_state.spec
84+
else:
85+
initial_request.append_object_spec = write_state.spec
86+
87+
yield initial_request
88+
7989
# The buffer should already be seeked to the correct position (persisted_size)
8090
# by the `recover_state_on_failure` method before this is called.
8191
while not write_state.is_complete:
8292
chunk = write_state.user_buffer.read(write_state.chunk_size)
8393

8494
# End of File detection
8595
if not chunk:
86-
write_state.is_complete = True
87-
yield storage_type.BidiWriteObjectRequest(
88-
write_offset=write_state.bytes_sent,
89-
finish_write=True,
90-
)
9196
return
9297

9398
checksummed_data = storage_type.ChecksummedData(content=chunk)
@@ -122,19 +127,25 @@ def update_state_from_response(
122127
async def recover_state_on_failure(
123128
self, error: Exception, state: Dict[str, Any]
124129
) -> None:
125-
"""Handles errors, specifically BidiWriteObjectRedirectedError, and rewinds state."""
130+
"""
131+
Handles errors, specifically BidiWriteObjectRedirectedError, and rewinds state.
132+
133+
This method rewinds the user buffer and internal byte tracking to the
134+
last confirmed 'persisted_size' from the server.
135+
"""
126136
write_state: _WriteState = state["write_state"]
127137
cause = getattr(error, "cause", error)
128138

129-
# Extract routing token and potentially a new write handle.
139+
# Extract routing token and potentially a new write handle for redirection.
130140
if isinstance(cause, BidiWriteObjectRedirectedError):
131141
if cause.routing_token:
132142
write_state.routing_token = cause.routing_token
133143

134-
if hasattr(cause, "write_handle") and cause.write_handle:
135-
write_state.write_handle = cause.write_handle
144+
redirect_handle = getattr(cause, "write_handle", None)
145+
if redirect_handle:
146+
write_state.write_handle = redirect_handle
136147

137148
# We must assume any data sent beyond 'persisted_size' was lost.
138-
# Reset the user buffer to the last known good byte.
149+
# Reset the user buffer to the last known good byte confirmed by the server.
139150
write_state.user_buffer.seek(write_state.persisted_size)
140151
write_state.bytes_sent = write_state.persisted_size

tests/unit/asyncio/retry/test_writes_resumption_strategy.py

Lines changed: 53 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -38,72 +38,86 @@ def test_ctor(self):
3838
strategy = self._make_one()
3939
self.assertIsInstance(strategy, self._get_target_class())
4040

41-
def test_generate_requests_initial(self):
41+
def test_generate_requests_initial_new_object(self):
4242
"""
43-
Verify the initial request sequence for a new upload.
44-
- First request is AppendObjectSpec with state_lookup=False.
45-
- Following requests are data chunks.
46-
- Final request has finish_write=True.
43+
Verify the initial request sequence for a new upload (WriteObjectSpec).
4744
"""
4845
strategy = self._make_one()
4946
mock_buffer = io.BytesIO(b"0123456789")
50-
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
47+
# Use WriteObjectSpec for new objects
48+
mock_spec = storage_type.WriteObjectSpec(
49+
resource=storage_type.Object(name="test-object")
50+
)
5151
state = {
5252
"write_state": _WriteState(
5353
mock_spec, chunk_size=4, user_buffer=mock_buffer
5454
),
55-
"first_request": True,
5655
}
5756

5857
requests = list(strategy.generate_requests(state))
5958

60-
self.assertEqual(requests[0].append_object_spec, mock_spec)
59+
# Check first request (Spec)
60+
self.assertEqual(requests[0].write_object_spec, mock_spec)
6161
self.assertFalse(requests[0].state_lookup)
62-
self.assertFalse(requests[0].append_object_spec.write_handle)
6362

63+
# Check data chunks
6464
self.assertEqual(requests[1].write_offset, 0)
6565
self.assertEqual(requests[1].checksummed_data.content, b"0123")
6666
self.assertEqual(requests[2].write_offset, 4)
6767
self.assertEqual(requests[2].checksummed_data.content, b"4567")
6868
self.assertEqual(requests[3].write_offset, 8)
6969
self.assertEqual(requests[3].checksummed_data.content, b"89")
7070

71-
self.assertEqual(requests[4].write_offset, 10)
72-
self.assertTrue(requests[4].finish_write)
73-
74-
self.assertEqual(len(requests), 5)
71+
# Total requests: 1 Spec + 3 Chunks
72+
self.assertEqual(len(requests), 4)
7573

76-
def test_generate_requests_empty_file(self):
74+
def test_generate_requests_initial_existing_object(self):
7775
"""
78-
Verify the request sequence for an empty file upload.
79-
- First request is AppendObjectSpec.
80-
- Second and final request has finish_write=True.
76+
Verify the initial request sequence for appending to an existing object (AppendObjectSpec).
8177
"""
8278
strategy = self._make_one()
83-
mock_buffer = io.BytesIO(b"")
84-
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
79+
mock_buffer = io.BytesIO(b"0123")
80+
# Use AppendObjectSpec for existing objects
81+
mock_spec = storage_type.AppendObjectSpec(
82+
object_="test-object", bucket="test-bucket"
83+
)
8584
state = {
8685
"write_state": _WriteState(
8786
mock_spec, chunk_size=4, user_buffer=mock_buffer
8887
),
89-
"first_request": True,
9088
}
9189

9290
requests = list(strategy.generate_requests(state))
9391

92+
# Check first request (Spec)
9493
self.assertEqual(requests[0].append_object_spec, mock_spec)
9594
self.assertFalse(requests[0].state_lookup)
9695

96+
# Check data chunk
9797
self.assertEqual(requests[1].write_offset, 0)
98-
self.assertTrue(requests[1].finish_write)
98+
self.assertEqual(requests[1].checksummed_data.content, b"0123")
99+
100+
def test_generate_requests_empty_file(self):
101+
"""
102+
Verify request sequence for an empty file. Should just be the Spec.
103+
"""
104+
strategy = self._make_one()
105+
mock_buffer = io.BytesIO(b"")
106+
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
107+
state = {
108+
"write_state": _WriteState(
109+
mock_spec, chunk_size=4, user_buffer=mock_buffer
110+
),
111+
}
112+
113+
requests = list(strategy.generate_requests(state))
99114

100-
self.assertEqual(len(requests), 2)
115+
self.assertEqual(len(requests), 1)
116+
self.assertEqual(requests[0].append_object_spec, mock_spec)
101117

102118
def test_generate_requests_resumption(self):
103119
"""
104120
Verify request sequence when resuming an upload.
105-
- First request is AppendObjectSpec with write_handle and state_lookup=True.
106-
- Data streaming starts from the persisted_size.
107121
"""
108122
strategy = self._make_one()
109123
mock_buffer = io.BytesIO(b"0123456789")
@@ -115,34 +129,26 @@ def test_generate_requests_resumption(self):
115129
write_state.write_handle = storage_type.BidiWriteHandle(handle=b"test-handle")
116130
mock_buffer.seek(4)
117131

118-
state = {"write_state": write_state, "first_request": True}
132+
state = {"write_state": write_state}
119133

120134
requests = list(strategy.generate_requests(state))
121135

136+
# Check first request has handle and lookup
122137
self.assertEqual(
123138
requests[0].append_object_spec.write_handle.handle, b"test-handle"
124139
)
125140
self.assertTrue(requests[0].state_lookup)
126141

142+
# Check data starts from offset 4
127143
self.assertEqual(requests[1].write_offset, 4)
128144
self.assertEqual(requests[1].checksummed_data.content, b"4567")
129145
self.assertEqual(requests[2].write_offset, 8)
130146
self.assertEqual(requests[2].checksummed_data.content, b"89")
131147

132-
self.assertEqual(requests[3].write_offset, 10)
133-
self.assertTrue(requests[3].finish_write)
134-
135-
self.assertEqual(len(requests), 4)
136-
137148
@pytest.mark.asyncio
138149
async def test_generate_requests_after_failure_and_recovery(self):
139150
"""
140-
Verify a complex scenario:
141-
1. Start upload.
142-
2. Receive a persisted_size update.
143-
3. Encounter an error.
144-
4. Recover state.
145-
5. Generate new requests for resumption.
151+
Verify recovery and resumption flow.
146152
"""
147153
strategy = self._make_one()
148154
mock_buffer = io.BytesIO(b"0123456789abcdef")
@@ -157,12 +163,11 @@ async def test_generate_requests_after_failure_and_recovery(self):
157163

158164
strategy.update_state_from_response(
159165
storage_type.BidiWriteObjectResponse(
160-
persisted_size=4, write_handle=b"handle-1"
166+
persisted_size=4,
167+
write_handle=storage_type.BidiWriteHandle(handle=b"handle-1"),
161168
),
162169
state,
163170
)
164-
self.assertEqual(write_state.persisted_size, 4)
165-
self.assertEqual(write_state.write_handle, b"handle-1")
166171

167172
await strategy.recover_state_on_failure(Exception("network error"), state)
168173

@@ -172,23 +177,15 @@ async def test_generate_requests_after_failure_and_recovery(self):
172177
requests = list(strategy.generate_requests(state))
173178

174179
self.assertTrue(requests[0].state_lookup)
175-
self.assertEqual(requests[0].append_object_spec.write_handle, b"handle-1")
180+
self.assertEqual(
181+
requests[0].append_object_spec.write_handle.handle, b"handle-1"
182+
)
176183

177184
self.assertEqual(requests[1].write_offset, 4)
178185
self.assertEqual(requests[1].checksummed_data.content, b"4567")
179-
self.assertEqual(requests[2].write_offset, 8)
180-
self.assertEqual(requests[2].checksummed_data.content, b"89ab")
181-
self.assertEqual(requests[3].write_offset, 12)
182-
self.assertEqual(requests[3].checksummed_data.content, b"cdef")
183-
184-
self.assertEqual(requests[4].write_offset, 16)
185-
self.assertTrue(requests[4].finish_write)
186-
self.assertEqual(len(requests), 5)
187186

188187
def test_update_state_from_response(self):
189-
"""
190-
Verify that the write state is correctly updated based on server responses.
191-
"""
188+
"""Verify state updates from server responses."""
192189
strategy = self._make_one()
193190
mock_buffer = io.BytesIO(b"0123456789")
194191
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
@@ -232,30 +229,10 @@ def test_update_state_from_response_ignores_smaller_persisted_size(self):
232229

233230
self.assertEqual(write_state.persisted_size, 2048)
234231

235-
@pytest.mark.asyncio
236-
async def test_recover_state_on_failure_rewinds_state(self):
237-
"""
238-
Verify that on failure, the buffer is seeked to persisted_size
239-
and bytes_sent is reset.
240-
"""
241-
strategy = self._make_one()
242-
mock_buffer = mock.MagicMock(spec=io.BytesIO)
243-
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
244-
245-
write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer)
246-
write_state.persisted_size = 100
247-
write_state.bytes_sent = 200
248-
state = {"write_state": write_state}
249-
250-
await strategy.recover_state_on_failure(Exception("any error"), state)
251-
252-
mock_buffer.seek.assert_called_once_with(100)
253-
self.assertEqual(write_state.bytes_sent, 100)
254-
255232
@pytest.mark.asyncio
256233
async def test_recover_state_on_failure_handles_redirect(self):
257234
"""
258-
Verify that on a redirect error, the routing_token is extracted and stored.
235+
Verify redirection error handling.
259236
"""
260237
strategy = self._make_one()
261238
mock_buffer = mock.MagicMock(spec=io.BytesIO)
@@ -271,11 +248,11 @@ async def test_recover_state_on_failure_handles_redirect(self):
271248
await strategy.recover_state_on_failure(wrapped_error, state)
272249

273250
self.assertEqual(write_state.routing_token, "new-token-123")
274-
mock_buffer.seek.assert_called_once_with(0)
275-
self.assertEqual(write_state.bytes_sent, 0)
251+
mock_buffer.seek.assert_called_with(0)
276252

277253
@pytest.mark.asyncio
278254
async def test_recover_state_on_failure_handles_redirect_with_handle(self):
255+
"""Verify redirection that includes a write handle."""
279256
strategy = self._make_one()
280257
mock_buffer = mock.MagicMock(spec=io.BytesIO)
281258
mock_spec = storage_type.AppendObjectSpec(object_="test-object")
@@ -294,8 +271,7 @@ async def test_recover_state_on_failure_handles_redirect_with_handle(self):
294271
self.assertEqual(write_state.routing_token, "new-token-456")
295272
self.assertEqual(write_state.write_handle, b"redirect-handle")
296273

297-
mock_buffer.seek.assert_called_once_with(0)
298-
self.assertEqual(write_state.bytes_sent, 0)
274+
mock_buffer.seek.assert_called_with(0)
299275

300276
def test_generate_requests_sends_crc32c_checksum(self):
301277
strategy = self._make_one()
@@ -305,13 +281,10 @@ def test_generate_requests_sends_crc32c_checksum(self):
305281
"write_state": _WriteState(
306282
mock_spec, chunk_size=4, user_buffer=mock_buffer
307283
),
308-
"first_request": True,
309284
}
310285

311286
requests = list(strategy.generate_requests(state))
312287

313-
self.assertEqual(len(requests), 3)
314-
315288
expected_crc = google_crc32c.Checksum(b"0123")
316289
expected_crc_int = int.from_bytes(expected_crc.digest(), "big")
317290
self.assertEqual(requests[1].checksummed_data.crc32c, expected_crc_int)
@@ -323,7 +296,7 @@ def test_generate_requests_with_routing_token(self):
323296

324297
write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer)
325298
write_state.routing_token = "redirected-token"
326-
state = {"write_state": write_state, "first_request": True}
299+
state = {"write_state": write_state}
327300

328301
requests = list(strategy.generate_requests(state))
329302

0 commit comments

Comments
 (0)