diff --git a/zcached/asyncio/connection.py b/zcached/asyncio/connection.py index 2374816..3e2eb7e 100644 --- a/zcached/asyncio/connection.py +++ b/zcached/asyncio/connection.py @@ -217,8 +217,10 @@ async def send(self, data: bytes) -> Result: logger.debug(f"{self.id} -> The connection has been terminated.") if not self.reconnect: return Result.fail(Errors.ConnectionClosed.value) - return await self.try_reconnect() + finally: + if self._pending_requests >= 1: + self._pending_requests -= 1 result: Result = await self.wait_for_response() if self.reconnect and result.error == Errors.ConnectionClosed: @@ -262,40 +264,33 @@ async def wait_for_response(self) -> Result: NOT TASK SAFE. """ if not self._reader: + logger.error( + f"{self.id} -> Missing StreamReader object! Did you forget to connect? " + f"Aborting the wait_for_response method..." + ) return Result.fail(Errors.ConnectionClosed.value) - complete_data: bytes = bytes() - try: - data: bytes | None = await self.receive(timeout_limit=self.timeout_limit) - if data is None: - self._connected = False - return Result.fail(Errors.ConnectionClosed.value) - except asyncio.TimeoutError: - return Result.fail(Errors.TimeoutLimit.value) + total_data: bytes = bytes() - complete_data += data - - while True: + while not total_data.endswith(b"\x03"): try: - data = await self.receive(timeout_limit=0.1) + data: bytes | None = await self.receive(timeout_limit=self.timeout_limit) except asyncio.TimeoutError: - break # Transfer complete. - if data is None or len(data) == 0: - # When socket lose connection to the server it receives empty bytes. - self._connected = False - return Result.fail(Errors.ConnectionClosed.value) + return Result.fail(Errors.TimeoutLimit.value) - complete_data += data + # When socket lose connection to the server it receives empty bytes. + # Or when the data is None, it means that the reader has been abandoned. + if data is None or len(data) == 0: # type: ignore + return Result.fail(Errors.ConnectionClosed.value) - if self._pending_requests >= 1: - self._pending_requests -= 1 + total_data += data # If the first byte is "-", it means that the response is an error. - if complete_data.startswith(b"-"): - error_message: str = complete_data.decode()[1:-2] + if total_data.startswith(b"-"): + error_message: str = total_data.decode()[1:-3] return Result.fail(error_message) - return Result.ok(complete_data) + return Result.ok(total_data[:-1]) async def close(self) -> None: """Closes the connection by closing the writer, and waiting until the writer is fully closed.""" diff --git a/zcached/connection.py b/zcached/connection.py index 0d3c6b0..4f7f74d 100644 --- a/zcached/connection.py +++ b/zcached/connection.py @@ -157,6 +157,8 @@ def receive(self) -> bytes | None: """ Method to receive the response from the server. None if there is no data in the socket yet. + + NOT THREAD SAFE. """ try: data: bytes = self.socket.recv(self.buffer_size) @@ -180,8 +182,6 @@ def send(self, data: bytes) -> Result: if self._lock.locked(): logging.debug(f"{self.id} -> Waiting for the thread lock to become available.") - self._pending_requests += 1 - with self._lock: try: logging.debug(f"{self.id} -> Sending data to the server -> %s", data) @@ -189,8 +189,10 @@ def send(self, data: bytes) -> Result: except (BrokenPipeError, OSError): if not self.reconnect: return Result.fail(Errors.ConnectionClosed.value) - return self.try_reconnect() + finally: + if self._pending_requests >= 1: + self._pending_requests -= 1 result: Result = self.wait_for_response() if not self.reconnect or result.error is None: @@ -227,45 +229,36 @@ def wait_for_response(self) -> Result: NOT THREAD SAFE. """ - backoff: ExponentialBackoff = ExponentialBackoff(0.1, 1.5, 0.5) + backoff: ExponentialBackoff = ExponentialBackoff(0.01, 3, 0.5) + total_data: bytes = bytes() - total_bytes: bytes = bytes() - transfer_complete: bool = False + # By doing this, we should receive the data at the first recv, without waiting for the backoff. + sleep(0.001) for timeout in backoff: data: bytes | None = self.receive() if not isinstance(data, bytes): - if len(total_bytes) > 0: - # If we already have some data, and this iteration gave us None, - # it means that the data transfer has been completed. - transfer_complete = True - else: - # We haven't received any data yet. - logging.debug(f"{self.id} -> There is no data in the socket. Timeout: {timeout}s.") - if backoff.total >= float(self.timeout_limit): - logging.error(f"{self.id} -> The waiting time limit for a response has been reached.") - return Result.fail(Errors.TimeoutLimit.value) - - sleep(timeout) - continue - - if transfer_complete: - if self._pending_requests >= 1: - self._pending_requests -= 1 + if backoff.total >= self.timeout_limit: + return Result.fail(Errors.TimeoutLimit.value) - # If the first byte is "-", it means that the response is an error. - if total_bytes.startswith(b"-"): - error_message: str = total_bytes.decode()[1:-2] - return Result.fail(error_message) - - return Result.ok(total_bytes) + logging.debug(f"{self.id} -> There is no data in the socket. Timeout: {timeout}s.") + sleep(timeout) + continue if len(data) == 0: # type: ignore # When socket lose connection to the server it receives empty bytes. return Result.fail(Errors.ConnectionClosed.value) - total_bytes += data # type: ignore + total_data += data # type: ignore + + if total_data.endswith(b"\x03"): # Received complete data. + # If the first byte is "-", it means that the response is an error. + if total_data.startswith(b"-"): + error_message: str = total_data.decode()[1:-3] + return Result.fail(error_message) + + return Result.ok(total_data[:-1]) # ExponentialBackoff should be increased only when we receive None. backoff.reset() diff --git a/zcached/enums.py b/zcached/enums.py index 63432a5..18105e1 100644 --- a/zcached/enums.py +++ b/zcached/enums.py @@ -36,16 +36,16 @@ def __repr__(self) -> str: class Commands(bytes, Enum): - PING = b"*1\r\n$4\r\nPING\r\n" - FLUSH = b"*1\r\n$5\r\nFLUSH\r\n" - DB_SIZE = b"*1\r\n$6\r\nDBSIZE\r\n" - SAVE = b"*1\r\n$4\r\nSAVE\r\n" - KEYS = b"*1\r\n$4\r\nKEYS\r\n" - LAST_SAVE = b"*1\r\n$8\r\nLASTSAVE\r\n" + PING = b"*1\r\n$4\r\nPING\r\n\x04" + FLUSH = b"*1\r\n$5\r\nFLUSH\r\n\x04" + DB_SIZE = b"*1\r\n$6\r\nDBSIZE\r\n\x04" + SAVE = b"*1\r\n$4\r\nSAVE\r\n\x04" + KEYS = b"*1\r\n$4\r\nKEYS\r\n\x04" + LAST_SAVE = b"*1\r\n$8\r\nLASTSAVE\r\n\x04" @staticmethod def get(key: str) -> bytes: - return f"*2\r\n$3\r\nGET\r\n${len(key)}\r\n{key}\r\n".encode() + return f"*2\r\n$3\r\nGET\r\n${len(key)}\r\n{key}\r\n\x04".encode() @staticmethod def mget(*keys: str) -> bytes: @@ -53,12 +53,12 @@ def mget(*keys: str) -> bytes: for key in keys: command += f"${len(key)}\r\n{key}\r\n" - return command.encode() + return (command + "\x04").encode() @staticmethod def set(key: str, value: SupportedTypes) -> bytes: serializer: Serializer = Serializer() - command: str = f"*3\r\n$3\r\nSET\r\n${len(key)}\r\n{key}\r\n{serializer.process(value)}" + command: str = f"*3\r\n$3\r\nSET\r\n${len(key)}\r\n{key}\r\n{serializer.process(value)}\x04" return command.encode() @staticmethod @@ -69,11 +69,11 @@ def mset(**params: SupportedTypes) -> bytes: for key, value in params.items(): command += f"${len(key)}\r\n{key}\r\n{serializer.process(value)}" - return command.encode() + return (command + "\x04").encode() @staticmethod def delete(key: str) -> bytes: - return f"*2\r\n$6\r\nDELETE\r\n${len(key)}\r\n{key}\r\n".encode() + return f"*2\r\n$6\r\nDELETE\r\n${len(key)}\r\n{key}\r\n\x04".encode() def __repr__(self) -> str: return f"{self.value}"