diff --git a/lib/backend.py b/lib/backend.py index 5cbb7ff..39ef7c7 100644 --- a/lib/backend.py +++ b/lib/backend.py @@ -9,6 +9,7 @@ from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional from functools import cached_property from distutils.util import strtobool +from collections import deque from anyio import open_file from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector @@ -30,7 +31,7 @@ BenchmarkResult ) -VERSION = "0.2.0" +VERSION = "0.2.1" MSG_HISTORY_LEN = 100 log = logging.getLogger(__file__) @@ -63,6 +64,7 @@ class Backend: version = VERSION msg_history = [] sem: Semaphore = dataclasses.field(default_factory=Semaphore) + queue: deque = dataclasses.field(default_factory=deque, repr=False) unsecured: bool = dataclasses.field( default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))), ) @@ -141,11 +143,26 @@ async def __handle_request( workload = payload.count_workload() request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created") - async def cancel_api_call_if_disconnected() -> web.Response: + + def advance_queue_after_completion(event: asyncio.Event): + """Pop current head and wake next waiter, if any.""" + # If this event is current head, wake next waiter + if self.queue and self.queue[0] is event: + self.queue.popleft() + if self.queue: + self.queue[0].set() + else: + # Else, remove it from the queue + try: + self.queue.remove(event) + except ValueError: + pass + + async def cancel_api_call_if_disconnected() -> None: await request.wait_for_disconnection() - log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled") + log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled") self.metrics._request_canceled(request_metrics) - raise asyncio.CancelledError + return async def make_request() -> Union[web.Response, web.StreamResponse]: try: @@ -162,7 +179,9 @@ async def make_request() -> Union[web.Response, web.StreamResponse]: res = await handler.generate_client_response(request, response) self.metrics._request_success(request_metrics) return res - except requests.exceptions.RequestException as e: + except asyncio.CancelledError: + raise + except Exception as e: log.debug(f"[backend] Request error: {e}") self.metrics._request_errored(request_metrics) return web.Response(status=500) @@ -177,46 +196,87 @@ async def make_request() -> Union[web.Response, web.StreamResponse]: self.metrics._request_reject(request_metrics) return web.Response(status=429) - acquired = False + disconnect_task = create_task(cancel_api_call_if_disconnected()) + next_request_task = None + work_task = None + event = asyncio.Event() # Used in finally block, so initialize here + + self.metrics._request_start(request_metrics) + try: - self.metrics._request_start(request_metrics) - if self.allow_parallel_requests is False: - log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}") - await self.sem.acquire() - acquired = True - log.debug( - f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..." - ) - else: + if self.allow_parallel_requests: log.debug(f"Starting request for reqnum:{request_metrics.reqnum}") - done, pending = await wait( - [ - create_task(make_request()), - create_task(cancel_api_call_if_disconnected()), - ], - return_when=FIRST_COMPLETED, - ) - for t in pending: - t.cancel() - await asyncio.gather(*pending, return_exceptions=True) + work_task = create_task(make_request()) + done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) - done_task = done.pop() - try: - return done_task.result() - except Exception as e: - log.debug(f"Request task raised exception: {e}") - return web.Response(status=500) + for t in pending: + t.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + if disconnect_task in done: + return web.Response(status=499) + + # otherwise work_task completed + return await work_task + + # FIFO-queue branch + else: + # Insert a Event into the queue for this request + # Event.set() == our request is up next + self.queue.append(event) + if self.queue and self.queue[0] is event: + event.set() + + # Race between our request being next and request being cancelled + next_request_task = create_task(event.wait()) + first_done, first_pending = await wait( + [next_request_task, disconnect_task], return_when=FIRST_COMPLETED + ) + + # If the disconnect task wins the race + if disconnect_task in first_done: + # Clean up the next_request_task, then exit + for t in first_pending: + t.cancel() + await asyncio.gather(*first_pending, return_exceptions=True) + return web.Response(status=499) + + # We are the next-up request in the queue + log.debug(f"Starting work on request {request_metrics.reqnum}...") + + # Race the backend API call with the disconnect task + work_task = create_task(make_request()) + + done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED) + for t in pending: + t.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + if disconnect_task in done: + return web.Response(status=499) + + # otherwise work_task completed + return await work_task + except asyncio.CancelledError: - # Client is gone. Do not write a response; just unwind. - return web.Response(status=499) + return web.Response(status=499) + except Exception as e: log.debug(f"Exception in main handler loop {e}") return web.Response(status=500) + finally: - # Always release the semaphore if it was acquired - if acquired: - self.sem.release() + if not self.allow_parallel_requests: + advance_queue_after_completion(event) + self.metrics._request_end(request_metrics) + cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t] + for t in cleanup_tasks: + if not t.done(): + t.cancel() + if cleanup_tasks: + await asyncio.gather(*cleanup_tasks, return_exceptions=True) + @cached_property def healthcheck_session(self):