Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mrok/agent/sidecar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __init__(
self,
target: str | Path | tuple[str, int],
*,
max_connections=1000,
max_keepalive_connections=10,
keepalive_expiry=120,
retries=0,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
retries: int = 0,
):
self._target = target
self._target_type, self._target_address = self._parse_target()
Expand All @@ -34,10 +34,10 @@ def __init__(

def setup_connection_pool(
self,
max_connections: int | None = 1000,
max_keepalive_connections: int | None = 10,
keepalive_expiry: float | None = 120.0,
retries: int = 0,
max_connections: int | None,
max_keepalive_connections: int | None,
keepalive_expiry: float | None,
retries: int,
) -> AsyncConnectionPool:
if self._target_type == "unix":
return AsyncConnectionPool(
Expand Down
24 changes: 23 additions & 1 deletion mrok/agent/sidecar/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def __init__(
identity_file: str,
target: str | Path | tuple[str, int],
workers: int = 4,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
retries: int = 0,
publishers_port: int = 50000,
subscribers_port: int = 50001,
):
Expand All @@ -24,22 +28,40 @@ def __init__(
events_sub_port=subscribers_port,
)
self._target = target
self._max_connections = max_connections
self._max_keepalive_connections = max_keepalive_connections
self._keepalive_expiry = keepalive_expiry
self._retries = retries

def get_asgi_app(self):
return SidecarProxyApp(self._target)
return SidecarProxyApp(
self._target,
max_connections=self._max_connections,
max_keepalive_connections=self._max_keepalive_connections,
keepalive_expiry=self._keepalive_expiry,
retries=self._retries,
)


def run(
identity_file: str,
target_addr: str | Path | tuple[str, int],
workers: int = 4,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
retries: int = 0,
publishers_port: int = 50000,
subscribers_port: int = 50001,
):
agent = SidecarAgent(
identity_file,
target_addr,
workers=workers,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
retries=retries,
publishers_port=publishers_port,
subscribers_port=subscribers_port,
)
Expand Down
51 changes: 50 additions & 1 deletion mrok/cli/commands/agent/run/sidecar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,55 @@ def run_sidecar(
typer.Option(
"--workers",
"-w",
help=f"Number of workers. Default: {default_workers}",
help="Number of workers.",
show_default=True,
),
] = default_workers,
max_connections: Annotated[
int,
typer.Option(
"--max-pool-connections",
help=(
"The maximum number of concurrent HTTP connections that "
"the pool should allow. Any attempt to send a request on a pool that "
"would exceed this amount will block until a connection is available."
),
show_default=True,
),
] = 10,
max_keepalive_connections: Annotated[
int | None,
typer.Option(
"--max-pool-keepalive-connections",
help=(
"The maximum number of idle HTTP connections "
"that will be maintained in the pool."
),
show_default=True,
),
] = None,
keepalive_expiry: Annotated[
float | None,
typer.Option(
"--max-pool-keepalive-expiry",
help=(
"The duration in seconds that an idle HTTP connection "
"may be maintained for before being expired from the pool."
),
show_default=True,
),
] = None,
retries: Annotated[
int,
typer.Option(
"--max-pool-connect-retries",
help=(
"The duration in seconds that an idle HTTP connection "
"may be maintained for before being expired from the pool."
),
show_default=True,
),
] = 0,
publishers_port: Annotated[
int,
typer.Option(
Expand Down Expand Up @@ -65,6 +110,10 @@ def run_sidecar(
str(identity_file),
target_addr,
workers=workers,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
retries=retries,
publishers_port=publishers_port,
subscribers_port=subscribers_port,
)
44 changes: 43 additions & 1 deletion mrok/cli/commands/frontend/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,48 @@ def run_frontend(
show_default=True,
),
] = default_workers,
max_connections: Annotated[
int,
typer.Option(
"--max-pool-connections",
help=(
"The maximum number of concurrent HTTP connections that "
"the pool should allow. Any attempt to send a request on a pool that "
"would exceed this amount will block until a connection is available."
),
show_default=True,
),
] = 1000,
max_keepalive_connections: Annotated[
int | None,
typer.Option(
"--max-pool-keepalive-connections",
help=(
"The maximum number of idle HTTP connections "
"that will be maintained in the pool."
),
show_default=True,
),
] = 100,
keepalive_expiry: Annotated[
float | None,
typer.Option(
"--max-pool-keepalive-expiry",
help=(
"The duration in seconds that an idle HTTP connection "
"may be maintained for before being expired from the pool."
),
show_default=True,
),
] = 300,
):
"""Run the mrok frontend with Gunicorn and Uvicorn workers."""
frontend.run(identity_file, host, port, workers)
frontend.run(
identity_file,
host,
port,
workers,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
)
14 changes: 7 additions & 7 deletions mrok/frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def __init__(
self,
identity_file: str,
*,
max_connections: int = 1000,
max_keepalive_connections: int = 10,
keepalive_expiry: float = 120.0,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
retries=0,
):
self._identity_file = identity_file
Expand All @@ -32,10 +32,10 @@ def __init__(

def setup_connection_pool(
self,
max_connections: int | None = 1000,
max_keepalive_connections: int | None = 100,
keepalive_expiry: float | None = 120.0,
retries: int = 0,
max_connections: int | None,
max_keepalive_connections: int | None,
keepalive_expiry: float | None,
retries: int,
) -> AsyncConnectionPool:
return AsyncConnectionPool(
max_connections=max_connections,
Expand Down
10 changes: 9 additions & 1 deletion mrok/frontend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,16 @@ def run(
host: str,
port: int,
workers: int,
max_connections: int | None,
max_keepalive_connections: int | None,
keepalive_expiry: float | None,
):
app = FrontendProxyApp(str(identity_file))
app = FrontendProxyApp(
str(identity_file),
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
)

options = {
"bind": f"{host}:{port}",
Expand Down
15 changes: 8 additions & 7 deletions mrok/proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class ProxyAppBase(abc.ABC):
def __init__(
self,
*,
max_connections: int | None = 1000,
max_keepalive_connections: int | None = 10,
keepalive_expiry: float | None = 120.0,
max_connections: int | None = 10,
max_keepalive_connections: int | None = None,
keepalive_expiry: float | None = None,
retries: int = 0,
) -> None:
self._pool = self.setup_connection_pool(
Expand All @@ -41,10 +41,10 @@ def __init__(
@abc.abstractmethod
def setup_connection_pool(
self,
max_connections: int | None = 1000,
max_keepalive_connections: int | None = 10,
keepalive_expiry: float | None = 120.0,
retries: int = 0,
max_connections: int | None,
max_keepalive_connections: int | None,
keepalive_expiry: float | None,
retries: int,
) -> AsyncConnectionPool:
raise NotImplementedError()

Expand Down Expand Up @@ -78,6 +78,7 @@ async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) ->
content=body_stream,
)
response = await self._pool.handle_async_request(request)
logger.debug(f"connection pool status: {self._pool}")
response_headers = []
for k, v in response.headers:
if k.lower() not in HOP_BY_HOP_HEADERS:
Expand Down
23 changes: 23 additions & 0 deletions mrok/proxy/streams.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import asyncio
import select
import sys
from typing import Any

from httpcore import AsyncNetworkStream

from mrok.proxy.types import ASGIReceive


def is_readable(sock): # pragma: no cover
# Stolen from
# https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471C1-L478C31

# use select.select on Windows, and select.poll everywhere else
if sys.platform == "win32":
rready, _, _ = select.select([sock], [], [], 0)
return bool(rready)
p = select.poll()
p.register(sock, select.POLLIN)
return bool(p.poll(0))


class AIONetworkStream(AsyncNetworkStream):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self._reader = reader
Expand All @@ -21,6 +37,13 @@ async def aclose(self) -> None:
self._writer.close()
await self._writer.wait_closed()

def get_extra_info(self, info: str) -> Any:
transport = self._writer.transport
if info == "is_readable":
sock = transport.get_extra_info("socket")
return is_readable(sock)
return transport.get_extra_info(info)


class ASGIRequestBodyStream:
def __init__(self, receive: ASGIReceive):
Expand Down
13 changes: 5 additions & 8 deletions mrok/proxy/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,23 @@ def __init__(
self._events_enabled = events_enabled
self._metrics_interval = metrics_interval
self.events_publisher_port = events_publisher_port
self._zmq_ctx = None
self._events_publisher = None
self._metrics_collector = WorkerMetricsCollector(self._worker_id)
self._zmq_ctx = zmq.asyncio.Context()
self._events_publisher = self._zmq_ctx.socket(zmq.PUB)
self._events_publish_task = None
self._metrics_collector = None

async def on_startup(self):
logger.info(f"Start events publishing for worker {self._worker_id}")
self._zmq_ctx = zmq.asyncio.Context()
self._events_publisher = self._zmq_ctx.socket(zmq.PUB)
self._events_publisher.connect(f"tcp://localhost:{self.events_publisher_port}")
self._metrics_collector = WorkerMetricsCollector(self._worker_id)
self._events_publish_task = asyncio.create_task(self.publish_metrics_event())
logger.info(f"Events publishing for worker {self._worker_id} started")

async def on_shutdown(self):
logger.info(f"Stop events publishing for worker {self._worker_id}")
self._events_publish_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._events_publish_task
self._events_publisher.close()
self._zmq_ctx.term()
logger.info(f"Events publishing for worker {self._worker_id} stopped")

@asynccontextmanager
async def lifespan(self, app: ASGIApp):
Expand Down
16 changes: 15 additions & 1 deletion tests/agent/sidecar/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ def test_sidecar_agent(mocker: MockerFixture):
)
assert agent.reload is False
assert agent.get_asgi_app() == mocked_app
mocked_app_ctor.assert_called_once_with(":8000")
mocked_app_ctor.assert_called_once_with(
":8000",
max_connections=10,
max_keepalive_connections=None,
keepalive_expiry=None,
retries=0,
)


def test_run(mocker: MockerFixture):
Expand All @@ -32,6 +38,10 @@ def test_run(mocker: MockerFixture):
"ziti-identity.json",
"target-addr",
workers=10,
max_connections=15,
max_keepalive_connections=3,
keepalive_expiry=100,
retries=6,
publishers_port=4000,
subscribers_port=5000,
)
Expand All @@ -40,6 +50,10 @@ def test_run(mocker: MockerFixture):
"ziti-identity.json",
"target-addr",
workers=10,
max_connections=15,
max_keepalive_connections=3,
retries=6,
keepalive_expiry=100,
publishers_port=4000,
subscribers_port=5000,
)
Expand Down
Loading
Loading