Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions zcached/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,10 @@ async def send(self, data: bytes) -> Result:
logger.debug(f"{self.id} -> The connection has been terminated.")
if not self.reconnect:
return Result.fail(Errors.ConnectionClosed.value)

return await self.try_reconnect()
finally:
if self._pending_requests >= 1:
self._pending_requests -= 1

result: Result = await self.wait_for_response()
if self.reconnect and result.error == Errors.ConnectionClosed:
Expand Down Expand Up @@ -262,40 +264,33 @@ async def wait_for_response(self) -> Result:
NOT TASK SAFE.
"""
if not self._reader:
logger.error(
f"{self.id} -> Missing StreamReader object! Did you forget to connect? "
f"Aborting the wait_for_response method..."
)
return Result.fail(Errors.ConnectionClosed.value)

complete_data: bytes = bytes()
try:
data: bytes | None = await self.receive(timeout_limit=self.timeout_limit)
if data is None:
self._connected = False
return Result.fail(Errors.ConnectionClosed.value)
except asyncio.TimeoutError:
return Result.fail(Errors.TimeoutLimit.value)
total_data: bytes = bytes()

complete_data += data

while True:
while not total_data.endswith(b"\x03"):
try:
data = await self.receive(timeout_limit=0.1)
data: bytes | None = await self.receive(timeout_limit=self.timeout_limit)
except asyncio.TimeoutError:
break # Transfer complete.
if data is None or len(data) == 0:
# When socket lose connection to the server it receives empty bytes.
self._connected = False
return Result.fail(Errors.ConnectionClosed.value)
return Result.fail(Errors.TimeoutLimit.value)

complete_data += data
# When socket lose connection to the server it receives empty bytes.
# Or when the data is None, it means that the reader has been abandoned.
if data is None or len(data) == 0: # type: ignore
return Result.fail(Errors.ConnectionClosed.value)

if self._pending_requests >= 1:
self._pending_requests -= 1
total_data += data

# If the first byte is "-", it means that the response is an error.
if complete_data.startswith(b"-"):
error_message: str = complete_data.decode()[1:-2]
if total_data.startswith(b"-"):
error_message: str = total_data.decode()[1:-3]
return Result.fail(error_message)

return Result.ok(complete_data)
return Result.ok(total_data[:-1])

async def close(self) -> None:
"""Closes the connection by closing the writer, and waiting until the writer is fully closed."""
Expand Down
53 changes: 23 additions & 30 deletions zcached/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def receive(self) -> bytes | None:
"""
Method to receive the response from the server.
None if there is no data in the socket yet.

NOT THREAD SAFE.
"""
try:
data: bytes = self.socket.recv(self.buffer_size)
Expand All @@ -180,17 +182,17 @@ def send(self, data: bytes) -> Result:
if self._lock.locked():
logging.debug(f"{self.id} -> Waiting for the thread lock to become available.")

self._pending_requests += 1

with self._lock:
try:
logging.debug(f"{self.id} -> Sending data to the server -> %s", data)
self.socket.send(data)
except (BrokenPipeError, OSError):
if not self.reconnect:
return Result.fail(Errors.ConnectionClosed.value)

return self.try_reconnect()
finally:
if self._pending_requests >= 1:
self._pending_requests -= 1

result: Result = self.wait_for_response()
if not self.reconnect or result.error is None:
Expand Down Expand Up @@ -227,45 +229,36 @@ def wait_for_response(self) -> Result:

NOT THREAD SAFE.
"""
backoff: ExponentialBackoff = ExponentialBackoff(0.1, 1.5, 0.5)
backoff: ExponentialBackoff = ExponentialBackoff(0.01, 3, 0.5)
total_data: bytes = bytes()

total_bytes: bytes = bytes()
transfer_complete: bool = False
# By doing this, we should receive the data at the first recv, without waiting for the backoff.
sleep(0.001)

for timeout in backoff:
data: bytes | None = self.receive()

if not isinstance(data, bytes):
if len(total_bytes) > 0:
# If we already have some data, and this iteration gave us None,
# it means that the data transfer has been completed.
transfer_complete = True
else:
# We haven't received any data yet.
logging.debug(f"{self.id} -> There is no data in the socket. Timeout: {timeout}s.")
if backoff.total >= float(self.timeout_limit):
logging.error(f"{self.id} -> The waiting time limit for a response has been reached.")
return Result.fail(Errors.TimeoutLimit.value)

sleep(timeout)
continue

if transfer_complete:
if self._pending_requests >= 1:
self._pending_requests -= 1
if backoff.total >= self.timeout_limit:
return Result.fail(Errors.TimeoutLimit.value)

# If the first byte is "-", it means that the response is an error.
if total_bytes.startswith(b"-"):
error_message: str = total_bytes.decode()[1:-2]
return Result.fail(error_message)

return Result.ok(total_bytes)
logging.debug(f"{self.id} -> There is no data in the socket. Timeout: {timeout}s.")
sleep(timeout)
continue

if len(data) == 0: # type: ignore
# When socket lose connection to the server it receives empty bytes.
return Result.fail(Errors.ConnectionClosed.value)

total_bytes += data # type: ignore
total_data += data # type: ignore

if total_data.endswith(b"\x03"): # Received complete data.
# If the first byte is "-", it means that the response is an error.
if total_data.startswith(b"-"):
error_message: str = total_data.decode()[1:-3]
return Result.fail(error_message)

return Result.ok(total_data[:-1])

# ExponentialBackoff should be increased only when we receive None.
backoff.reset()
Expand Down
22 changes: 11 additions & 11 deletions zcached/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,29 @@ def __repr__(self) -> str:


class Commands(bytes, Enum):
PING = b"*1\r\n$4\r\nPING\r\n"
FLUSH = b"*1\r\n$5\r\nFLUSH\r\n"
DB_SIZE = b"*1\r\n$6\r\nDBSIZE\r\n"
SAVE = b"*1\r\n$4\r\nSAVE\r\n"
KEYS = b"*1\r\n$4\r\nKEYS\r\n"
LAST_SAVE = b"*1\r\n$8\r\nLASTSAVE\r\n"
PING = b"*1\r\n$4\r\nPING\r\n\x04"
FLUSH = b"*1\r\n$5\r\nFLUSH\r\n\x04"
DB_SIZE = b"*1\r\n$6\r\nDBSIZE\r\n\x04"
SAVE = b"*1\r\n$4\r\nSAVE\r\n\x04"
KEYS = b"*1\r\n$4\r\nKEYS\r\n\x04"
LAST_SAVE = b"*1\r\n$8\r\nLASTSAVE\r\n\x04"

@staticmethod
def get(key: str) -> bytes:
return f"*2\r\n$3\r\nGET\r\n${len(key)}\r\n{key}\r\n".encode()
return f"*2\r\n$3\r\nGET\r\n${len(key)}\r\n{key}\r\n\x04".encode()

@staticmethod
def mget(*keys: str) -> bytes:
command: str = f"*{1 + len(keys)}\r\n$4\r\nMGET\r\n"
for key in keys:
command += f"${len(key)}\r\n{key}\r\n"

return command.encode()
return (command + "\x04").encode()

@staticmethod
def set(key: str, value: SupportedTypes) -> bytes:
serializer: Serializer = Serializer()
command: str = f"*3\r\n$3\r\nSET\r\n${len(key)}\r\n{key}\r\n{serializer.process(value)}"
command: str = f"*3\r\n$3\r\nSET\r\n${len(key)}\r\n{key}\r\n{serializer.process(value)}\x04"
return command.encode()

@staticmethod
Expand All @@ -69,11 +69,11 @@ def mset(**params: SupportedTypes) -> bytes:
for key, value in params.items():
command += f"${len(key)}\r\n{key}\r\n{serializer.process(value)}"

return command.encode()
return (command + "\x04").encode()

@staticmethod
def delete(key: str) -> bytes:
return f"*2\r\n$6\r\nDELETE\r\n${len(key)}\r\n{key}\r\n".encode()
return f"*2\r\n$6\r\nDELETE\r\n${len(key)}\r\n{key}\r\n\x04".encode()

def __repr__(self) -> str:
return f"{self.value}"