Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 96 additions & 36 deletions lib/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from collections import deque

from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
Expand All @@ -30,7 +31,7 @@
BenchmarkResult
)

VERSION = "0.2.0"
VERSION = "0.2.1"

MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
Expand Down Expand Up @@ -63,6 +64,7 @@ class Backend:
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
Expand Down Expand Up @@ -141,11 +143,26 @@ async def __handle_request(
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")

async def cancel_api_call_if_disconnected() -> web.Response:

def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
# If this event is current head, wake next waiter
if self.queue and self.queue[0] is event:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check whether `self.queue[0]` is an event here? Minor nit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't checking whether it is an event; it verifies that the event we pass in is in fact the current head of the queue.

self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
# Else, remove it from the queue
try:
self.queue.remove(event)
except ValueError:
pass

async def cancel_api_call_if_disconnected() -> None:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
log.debug(f"Request with reqnum: {request_metrics.reqnum} was canceled")
self.metrics._request_canceled(request_metrics)
raise asyncio.CancelledError
return

async def make_request() -> Union[web.Response, web.StreamResponse]:
try:
Expand All @@ -162,7 +179,9 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics)
return res
except requests.exceptions.RequestException as e:
except asyncio.CancelledError:
raise
except Exception as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
Expand All @@ -177,46 +196,87 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)

acquired = False
disconnect_task = create_task(cancel_api_call_if_disconnected())
next_request_task = None
work_task = None
event = asyncio.Event() # Used in finally block, so initialize here

self.metrics._request_start(request_metrics)

try:
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
if self.allow_parallel_requests:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)

done_task = done.pop()
try:
return done_task.result()
except Exception as e:
log.debug(f"Request task raised exception: {e}")
return web.Response(status=500)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)

if disconnect_task in done:
return web.Response(status=499)

# otherwise work_task completed
return await work_task

# FIFO-queue branch
else:
# Insert an Event into the queue for this request
# Event.set() == our request is up next
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()

# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
)

# If the disconnect task wins the race
if disconnect_task in first_done:
# Clean up the next_request_task, then exit
for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)

# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")

# Race the backend API call with the disconnect task
work_task = create_task(make_request())

done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)

if disconnect_task in done:
return web.Response(status=499)

# otherwise work_task completed
return await work_task

except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
return web.Response(status=499)

except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)

finally:
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
if not self.allow_parallel_requests:
advance_queue_after_completion(event)

self.metrics._request_end(request_metrics)
cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task) if t]
for t in cleanup_tasks:
if not t.done():
t.cancel()
if cleanup_tasks:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)


@cached_property
def healthcheck_session(self):
Expand Down