From 4f3dc4191449d48f60d2a39c8df1e742b5e1753f Mon Sep 17 00:00:00 2001 From: Aleksandra Ovchinnikova Date: Mon, 29 Dec 2025 19:01:36 -0500 Subject: [PATCH] MPT-16669 Improve mrok proxy module and fix identity name --- mrok/agent/devtools/inspector/app.py | 4 +- mrok/agent/sidecar/app.py | 2 +- mrok/agent/sidecar/main.py | 4 + mrok/agent/ziticorn.py | 2 +- mrok/cli/commands/admin/bootstrap.py | 5 +- mrok/cli/commands/admin/utils.py | 4 +- mrok/cli/commands/agent/run/sidecar.py | 9 + mrok/controller/schemas.py | 4 +- mrok/frontend/app.py | 2 +- mrok/proxy/app.py | 4 +- mrok/proxy/asgi.py | 8 +- mrok/proxy/backend.py | 2 +- mrok/proxy/event_publisher.py | 66 ++++++ mrok/proxy/lifespan.py | 10 - mrok/proxy/master.py | 4 +- mrok/proxy/metrics.py | 4 +- mrok/proxy/{middlewares.py => middleware.py} | 10 +- mrok/proxy/{datastructures.py => models.py} | 16 +- mrok/proxy/protocol.py | 11 - mrok/proxy/server.py | 14 -- mrok/proxy/{streams.py => stream.py} | 2 +- mrok/proxy/worker.py | 74 ++----- mrok/proxy/{config.py => ziticorn.py} | 35 +++- mrok/types/__init__.py | 0 mrok/{proxy/types.py => types/proxy.py} | 2 +- mrok/types/ziti.py | 1 + mrok/ziti/api.py | 33 ++- mrok/ziti/bootstrap.py | 5 +- mrok/ziti/identities.py | 9 +- mrok/ziti/services.py | 5 +- tests/agent/sidecar/test_main.py | 1 + tests/cli/agent/test_run.py | 1 + tests/conftest.py | 2 +- tests/proxy/test_app.py | 2 +- tests/proxy/test_asgi.py | 2 +- tests/proxy/test_config.py | 22 +- tests/proxy/test_event_publisher.py | 190 +++++++++++++++++ tests/proxy/test_lifespan.py | 4 +- tests/proxy/test_master.py | 2 +- tests/proxy/test_metrics.py | 5 +- ...test_middlewares.py => test_middleware.py} | 8 +- tests/proxy/test_protocol.py | 7 +- tests/proxy/test_server.py | 27 +-- .../proxy/{test_streams.py => test_stream.py} | 4 +- tests/proxy/test_worker.py | 192 +----------------- tests/types.py | 2 +- tests/ziti/test_api.py | 4 +- tests/ziti/test_identities.py | 64 +++--- tests/ziti/test_services.py | 12 +- 49 files changed, 469 insertions(+), 433 deletions(-) create mode 100644 mrok/proxy/event_publisher.py delete mode 100644 mrok/proxy/lifespan.py rename mrok/proxy/{middlewares.py => middleware.py} (94%) rename mrok/proxy/{datastructures.py => models.py} (94%) delete mode 100644 mrok/proxy/protocol.py delete mode 100644 mrok/proxy/server.py rename mrok/proxy/{streams.py => stream.py} (98%) rename mrok/proxy/{config.py => ziticorn.py} (62%) create mode 100644 mrok/types/__init__.py rename mrok/{proxy/types.py => types/proxy.py} (93%) create mode 100644 mrok/types/ziti.py create mode 100644 tests/proxy/test_event_publisher.py rename tests/proxy/{test_middlewares.py => test_middleware.py} (95%) rename tests/proxy/{test_streams.py => test_stream.py} (96%) diff --git a/mrok/agent/devtools/inspector/app.py b/mrok/agent/devtools/inspector/app.py index 486aa0b..2e87e0b 100755 --- a/mrok/agent/devtools/inspector/app.py +++ b/mrok/agent/devtools/inspector/app.py @@ -27,7 +27,7 @@ from textual.worker import get_current_worker from mrok import __version__ -from mrok.proxy.datastructures import Event, HTTPHeaders, HTTPResponse, WorkerMetrics, ZitiMrokMeta +from mrok.proxy.models import Event, HTTPHeaders, HTTPResponse, ServiceMetadata, WorkerMetrics def build_tree(node, data): @@ -185,7 +185,7 @@ def on_mount(self) -> None: # mem=int(mean([m.process.mem for m in self.workers_metrics.values()])), # ) - def update_meta(self, meta: ZitiMrokMeta) -> None: + def update_meta(self, meta: ServiceMetadata) -> None: table = self.query_one(DataTable) if len(table.rows) == 0: table.add_row("URL", f"https://{meta.extension}.{meta.domain}") diff --git a/mrok/agent/sidecar/app.py b/mrok/agent/sidecar/app.py index 1d63af4..f09d3da 100644 --- a/mrok/agent/sidecar/app.py +++ b/mrok/agent/sidecar/app.py @@ -5,7 +5,7 @@ from httpcore import AsyncConnectionPool from mrok.proxy.app import ProxyAppBase -from mrok.proxy.types import Scope +from mrok.types.proxy import Scope logger = logging.getLogger("mrok.agent") diff --git a/mrok/agent/sidecar/main.py b/mrok/agent/sidecar/main.py index b08c731..6731215 100644 --- a/mrok/agent/sidecar/main.py +++ b/mrok/agent/sidecar/main.py @@ -13,6 +13,7 @@ def __init__( identity_file: str, target: str | Path | tuple[str, int], workers: int = 4, + events_enabled: bool = True, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -24,6 +25,7 @@ def __init__( identity_file, workers=workers, reload=False, + events_enabled=events_enabled, events_pub_port=publishers_port, events_sub_port=subscribers_port, ) @@ -47,6 +49,7 @@ def run( identity_file: str, target_addr: str | Path | tuple[str, int], workers: int = 4, + events_enabled: bool = True, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -58,6 +61,7 @@ def run( identity_file, target_addr, workers=workers, + events_enabled=events_enabled, max_connections=max_connections, max_keepalive_connections=max_keepalive_connections, keepalive_expiry=keepalive_expiry, diff --git a/mrok/agent/ziticorn.py b/mrok/agent/ziticorn.py index f4c61dd..ff15f83 100644 --- a/mrok/agent/ziticorn.py +++ b/mrok/agent/ziticorn.py @@ -1,5 +1,5 @@ from mrok.proxy.master import MasterBase -from mrok.proxy.types import ASGIApp +from mrok.types.proxy import ASGIApp class ZiticornAgent(MasterBase): diff --git a/mrok/cli/commands/admin/bootstrap.py b/mrok/cli/commands/admin/bootstrap.py index b5c2951..4273ed4 100644 --- a/mrok/cli/commands/admin/bootstrap.py +++ b/mrok/cli/commands/admin/bootstrap.py @@ -8,14 +8,15 @@ from mrok.cli.commands.admin.utils import parse_tags from mrok.conf import Settings -from mrok.ziti.api import TagsType, ZitiClientAPI, ZitiManagementAPI +from mrok.types.ziti import Tags +from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI from mrok.ziti.bootstrap import bootstrap_identity logger = logging.getLogger(__name__) async def bootstrap( - settings: Settings, forced: bool, tags: TagsType | None + settings: Settings, forced: bool, tags: Tags | None ) -> tuple[str, dict[str, Any] | None]: async with ZitiManagementAPI(settings) as mgmt_api, ZitiClientAPI(settings) as client_api: return await bootstrap_identity( diff --git a/mrok/cli/commands/admin/utils.py b/mrok/cli/commands/admin/utils.py index 4156bd1..4650113 100644 --- a/mrok/cli/commands/admin/utils.py +++ b/mrok/cli/commands/admin/utils.py @@ -2,10 +2,10 @@ import typer -from mrok.ziti.api import TagsType +from mrok.types.ziti import Tags -def parse_tags(pairs: list[str] | None) -> TagsType | None: +def parse_tags(pairs: list[str] | None) -> Tags | None: if not pairs: return None diff --git a/mrok/cli/commands/agent/run/sidecar.py b/mrok/cli/commands/agent/run/sidecar.py index d553330..e9531ed 100644 --- a/mrok/cli/commands/agent/run/sidecar.py +++ b/mrok/cli/commands/agent/run/sidecar.py @@ -98,6 +98,14 @@ def run_sidecar( show_default=True, ), ] = 50001, + no_events: Annotated[ + bool, + typer.Option( + "--no-events", + help="Disable events. Default: False", + show_default=True, + ), + ] = False, ): """Run a Sidecar Proxy to expose a web application through OpenZiti.""" if ":" in str(target): @@ -110,6 +118,7 @@ def run_sidecar( str(identity_file), target_addr, workers=workers, + events_enabled=not no_events, max_connections=max_connections, max_keepalive_connections=max_keepalive_connections, keepalive_expiry=keepalive_expiry, diff --git a/mrok/controller/schemas.py b/mrok/controller/schemas.py index c87f548..e9573de 100644 --- a/mrok/controller/schemas.py +++ b/mrok/controller/schemas.py @@ -9,12 +9,12 @@ computed_field, ) -from mrok.ziti.api import TagsType +from mrok.types.ziti import Tags class BaseSchema(BaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") - tags: TagsType | None = None + tags: Tags | None = None class IdSchema(BaseModel): diff --git a/mrok/frontend/app.py b/mrok/frontend/app.py index 7c698c5..c081d44 100644 --- a/mrok/frontend/app.py +++ b/mrok/frontend/app.py @@ -6,7 +6,7 @@ from mrok.proxy.app import ProxyAppBase from mrok.proxy.backend import AIOZitiNetworkBackend from mrok.proxy.exceptions import InvalidTargetError -from mrok.proxy.types import Scope +from mrok.types.proxy import Scope RE_SUBDOMAIN = re.compile(r"(?i)^(?:EXT-\d{4}-\d{4}|INS-\d{4}-\d{4}-\d{4})$") diff --git a/mrok/proxy/app.py b/mrok/proxy/app.py index 6b083af..fce4e8f 100644 --- a/mrok/proxy/app.py +++ b/mrok/proxy/app.py @@ -4,8 +4,8 @@ from httpcore import AsyncConnectionPool, Request from mrok.proxy.exceptions import ProxyError -from mrok.proxy.streams import ASGIRequestBodyStream -from mrok.proxy.types import ASGIReceive, ASGISend, Scope +from mrok.proxy.stream import ASGIRequestBodyStream +from mrok.types.proxy import ASGIReceive, ASGISend, Scope logger = logging.getLogger("mrok.proxy") diff --git a/mrok/proxy/asgi.py b/mrok/proxy/asgi.py index ef607b4..d7aff67 100644 --- a/mrok/proxy/asgi.py +++ b/mrok/proxy/asgi.py @@ -2,7 +2,7 @@ from contextlib import AsyncExitStack, asynccontextmanager from typing import Any, ParamSpec, Protocol -from mrok.proxy.types import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope +from mrok.types.proxy import ASGIApp, ASGIReceive, ASGISend, Lifespan, Scope P = ParamSpec("P") @@ -57,9 +57,9 @@ async def merge_lifespan(self, app: ASGIApp): if self.lifespan is not None: outer_state = await stack.enter_async_context(self.lifespan(app)) state.update(outer_state or {}) - starlette_lifesapn = self.get_starlette_lifespan() - if starlette_lifesapn is not None: - inner_state = await stack.enter_async_context(starlette_lifesapn(app)) + starlette_lifespan = self.get_starlette_lifespan() + if starlette_lifespan is not None: + inner_state = await stack.enter_async_context(starlette_lifespan(app)) state.update(inner_state or {}) yield state diff --git a/mrok/proxy/backend.py b/mrok/proxy/backend.py index 190e8c1..f9fe50c 100644 --- a/mrok/proxy/backend.py +++ b/mrok/proxy/backend.py @@ -7,7 +7,7 @@ from openziti.context import ZitiContext from mrok.proxy.exceptions import InvalidTargetError, TargetUnavailableError -from mrok.proxy.streams import AIONetworkStream +from mrok.proxy.stream import AIONetworkStream class AIOZitiNetworkBackend(AsyncNetworkBackend): diff --git a/mrok/proxy/event_publisher.py b/mrok/proxy/event_publisher.py new file mode 100644 index 0000000..9c1289a --- /dev/null +++ b/mrok/proxy/event_publisher.py @@ -0,0 +1,66 @@ +import asyncio +import contextlib +import logging + +import zmq +import zmq.asyncio + +from mrok.proxy.asgi import ASGIAppWrapper +from mrok.proxy.metrics import MetricsCollector +from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware +from mrok.proxy.models import Event, HTTPResponse, ServiceMetadata, Status +from mrok.types.proxy import ASGIApp + +logger = logging.getLogger("mrok.proxy") + + +class EventPublisher: + def __init__( + self, + worker_id: str, + meta: ServiceMetadata | None = None, + event_publisher_port: int = 50000, + metrics_interval: float = 5.0, + ): + self._worker_id = worker_id + self._meta = meta + self._metrics_interval = metrics_interval + self.publisher_port = event_publisher_port + self._zmq_ctx = zmq.asyncio.Context() + self._publisher = self._zmq_ctx.socket(zmq.PUB) + self._metrics_collector = MetricsCollector(self._worker_id) + self._publish_task = None + + async def on_startup(self): + self._publisher.connect(f"tcp://localhost:{self.publisher_port}") + self._publish_task = asyncio.create_task(self.publish_metrics_event()) + logger.info(f"Events publishing for worker {self._worker_id} started") + + async def on_shutdown(self): + self._publish_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._publish_task + self._publisher.close() + self._zmq_ctx.term() + logger.info(f"Events publishing for worker {self._worker_id} stopped") + + async def publish_metrics_event(self): + while True: + snap = await self._metrics_collector.snapshot() + event = Event(type="status", data=Status(meta=self._meta, metrics=snap)) + await self._publisher.send_string(event.model_dump_json()) + await asyncio.sleep(self._metrics_interval) + + async def publish_response_event(self, response: HTTPResponse): + event = Event(type="response", data=response) + await self._publisher.send_string(event.model_dump_json()) # type: ignore[attr-defined] + + def setup_middleware(self, app: ASGIAppWrapper): + app.add_middleware(CaptureMiddleware, self.publish_response_event) + app.add_middleware(MetricsMiddleware, self._metrics_collector) # type: ignore + + @contextlib.asynccontextmanager + async def lifespan(self, app: ASGIApp): + await self.on_startup() # type: ignore + yield + await self.on_shutdown() # type: ignore diff --git a/mrok/proxy/lifespan.py b/mrok/proxy/lifespan.py deleted file mode 100644 index 4f2b9ff..0000000 --- a/mrok/proxy/lifespan.py +++ /dev/null @@ -1,10 +0,0 @@ -import logging - -from uvicorn.config import Config -from uvicorn.lifespan.on import LifespanOn - - -class MrokLifespan(LifespanOn): - def __init__(self, config: Config) -> None: - super().__init__(config) - self.logger = logging.getLogger("mrok.proxy") diff --git a/mrok/proxy/master.py b/mrok/proxy/master.py index 3485517..b726b3f 100644 --- a/mrok/proxy/master.py +++ b/mrok/proxy/master.py @@ -14,8 +14,8 @@ from mrok.conf import get_settings from mrok.logging import setup_logging -from mrok.proxy.types import ASGIApp from mrok.proxy.worker import Worker +from mrok.types.proxy import ASGIApp logger = logging.getLogger("mrok.agent") @@ -67,7 +67,7 @@ def start_uvicorn_worker( app, identity_file, events_enabled=events_enabled, - events_publisher_port=events_pub_port, + event_publisher_port=events_pub_port, metrics_interval=metrics_interval, ) worker.run() diff --git a/mrok/proxy/metrics.py b/mrok/proxy/metrics.py index 4e9d307..ce1cb37 100644 --- a/mrok/proxy/metrics.py +++ b/mrok/proxy/metrics.py @@ -8,7 +8,7 @@ import psutil from hdrh.histogram import HdrHistogram -from mrok.proxy.datastructures import ( +from mrok.proxy.models import ( DataTransferMetrics, ProcessMetrics, RequestsMetrics, @@ -49,7 +49,7 @@ async def get_process_metrics(interval: float = 0.1) -> ProcessMetrics: return await asyncio.to_thread(_collect_process_usage, interval) -class WorkerMetricsCollector: +class MetricsCollector: def __init__(self, worker_id: str, lowest=1, highest=60000, sigfigs=3): self.worker_id = worker_id self.total_requests = 0 diff --git a/mrok/proxy/middlewares.py b/mrok/proxy/middleware.py similarity index 94% rename from mrok/proxy/middlewares.py rename to mrok/proxy/middleware.py index 370fd02..bcda3bb 100644 --- a/mrok/proxy/middlewares.py +++ b/mrok/proxy/middleware.py @@ -3,9 +3,10 @@ import time from mrok.proxy.constants import MAX_REQUEST_BODY_BYTES, MAX_RESPONSE_BODY_BYTES -from mrok.proxy.datastructures import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse -from mrok.proxy.metrics import WorkerMetricsCollector -from mrok.proxy.types import ( +from mrok.proxy.metrics import MetricsCollector +from mrok.proxy.models import FixedSizeByteBuffer, HTTPHeaders, HTTPRequest, HTTPResponse +from mrok.proxy.utils import must_capture_request, must_capture_response +from mrok.types.proxy import ( ASGIApp, ASGIReceive, ASGISend, @@ -13,7 +14,6 @@ ResponseCompleteCallback, Scope, ) -from mrok.proxy.utils import must_capture_request, must_capture_response logger = logging.getLogger("mrok.proxy") @@ -98,7 +98,7 @@ async def send_wrapper(msg: Message): class MetricsMiddleware: - def __init__(self, app: ASGIApp, metrics: WorkerMetricsCollector): + def __init__(self, app: ASGIApp, metrics: MetricsCollector): self.app = app self.metrics = metrics diff --git a/mrok/proxy/datastructures.py b/mrok/proxy/models.py similarity index 94% rename from mrok/proxy/datastructures.py rename to mrok/proxy/models.py index 157a956..238362a 100644 --- a/mrok/proxy/datastructures.py +++ b/mrok/proxy/models.py @@ -8,7 +8,7 @@ from pydantic_core import core_schema -class ZitiId(BaseModel): +class X509Credentials(BaseModel): key: str cert: str ca: str @@ -21,7 +21,7 @@ def strip_pem_prefix(cls, value: str) -> str: return value -class ZitiMrokMeta(BaseModel): +class ServiceMetadata(BaseModel): model_config = ConfigDict(extra="ignore") identity: str extension: str @@ -30,21 +30,21 @@ class ZitiMrokMeta(BaseModel): tags: dict[str, str | bool | None] | None = None -class ZitiIdentity(BaseModel): +class Identity(BaseModel): model_config = ConfigDict(extra="ignore") zt_api: str = Field(validation_alias="ztAPI") - id: ZitiId + id: X509Credentials zt_apis: str | None = Field(default=None, validation_alias="ztAPIs") config_types: str | None = Field(default=None, validation_alias="configTypes") enable_ha: bool = Field(default=False, validation_alias="enableHa") - mrok: ZitiMrokMeta | None = None + mrok: ServiceMetadata | None = None @staticmethod - def load_from_file(path: str | Path) -> ZitiIdentity: + def load_from_file(path: str | Path) -> Identity: path = Path(path) with path.open("r", encoding="utf-8") as f: data = json.load(f) - return ZitiIdentity.model_validate(data) + return Identity.model_validate(data) class FixedSizeByteBuffer: @@ -183,7 +183,7 @@ class WorkerMetrics(BaseModel): class Status(BaseModel): type: Literal["status"] = "status" - meta: ZitiMrokMeta + meta: ServiceMetadata metrics: WorkerMetrics diff --git a/mrok/proxy/protocol.py b/mrok/proxy/protocol.py deleted file mode 100644 index 991c811..0000000 --- a/mrok/proxy/protocol.py +++ /dev/null @@ -1,11 +0,0 @@ -import logging - -from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol - - -class MrokHttpToolsProtocol(HttpToolsProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.logger = logging.getLogger("mrok.proxy") - self.access_logger = logging.getLogger("mrok.access") - self.access_log = self.access_logger.hasHandlers() diff --git a/mrok/proxy/server.py b/mrok/proxy/server.py deleted file mode 100644 index 7fbe8ae..0000000 --- a/mrok/proxy/server.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging -import socket - -from uvicorn import server - -server.logger = logging.getLogger("mrok.proxy") - - -class MrokServer(server.Server): - async def serve(self, sockets: list[socket.socket] | None = None) -> None: - if not sockets: - sockets = [self.config.bind_socket()] - with self.capture_signals(): - await self._serve(sockets) diff --git a/mrok/proxy/streams.py b/mrok/proxy/stream.py similarity index 98% rename from mrok/proxy/streams.py rename to mrok/proxy/stream.py index 023b976..2efd83e 100644 --- a/mrok/proxy/streams.py +++ b/mrok/proxy/stream.py @@ -5,7 +5,7 @@ from httpcore import AsyncNetworkStream -from mrok.proxy.types import ASGIReceive +from mrok.types.proxy import ASGIReceive def is_readable(sock): # pragma: no cover diff --git a/mrok/proxy/worker.py b/mrok/proxy/worker.py index 7e63278..cb1adda 100644 --- a/mrok/proxy/worker.py +++ b/mrok/proxy/worker.py @@ -1,22 +1,17 @@ import asyncio import contextlib import logging -from contextlib import asynccontextmanager from pathlib import Path -import zmq -import zmq.asyncio from uvicorn.importer import import_from_string from mrok.conf import get_settings from mrok.logging import setup_logging from mrok.proxy.asgi import ASGIAppWrapper -from mrok.proxy.config import MrokBackendConfig -from mrok.proxy.datastructures import Event, HTTPResponse, Status, ZitiIdentity -from mrok.proxy.metrics import WorkerMetricsCollector -from mrok.proxy.middlewares import CaptureMiddleware, MetricsMiddleware -from mrok.proxy.server import MrokServer -from mrok.proxy.types import ASGIApp +from mrok.proxy.event_publisher import EventPublisher +from mrok.proxy.models import Identity +from mrok.proxy.ziticorn import BackendConfig, Server +from mrok.types.proxy import ASGIApp logger = logging.getLogger("mrok.proxy") @@ -29,68 +24,41 @@ def __init__( identity_file: str | Path, *, events_enabled: bool = True, - events_publisher_port: int = 50000, + event_publisher_port: int = 50000, metrics_interval: float = 5.0, ): self._worker_id = worker_id self._identity_file = identity_file - self._identity = ZitiIdentity.load_from_file(self._identity_file) + self._identity = Identity.load_from_file(self._identity_file) self._app = app - self._events_enabled = events_enabled - self._metrics_interval = metrics_interval - self.events_publisher_port = events_publisher_port - self._metrics_collector = WorkerMetricsCollector(self._worker_id) - self._zmq_ctx = zmq.asyncio.Context() - self._events_publisher = self._zmq_ctx.socket(zmq.PUB) - self._events_publish_task = None - - async def on_startup(self): - self._events_publisher.connect(f"tcp://localhost:{self.events_publisher_port}") - self._events_publish_task = asyncio.create_task(self.publish_metrics_event()) - logger.info(f"Events publishing for worker {self._worker_id} started") - - async def on_shutdown(self): - self._events_publish_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._events_publish_task - self._events_publisher.close() - self._zmq_ctx.term() - logger.info(f"Events publishing for worker {self._worker_id} stopped") - - @asynccontextmanager - async def lifespan(self, app: ASGIApp): - await self.on_startup() - yield - await self.on_shutdown() - async def publish_metrics_event(self): - while True: - snap = await self._metrics_collector.snapshot() - event = Event(type="status", data=Status(meta=self._identity.mrok, metrics=snap)) - await self._events_publisher.send_string(event.model_dump_json()) - await asyncio.sleep(self._metrics_interval) - - async def publish_response_event(self, response: HTTPResponse): - event = Event(type="response", data=response) - await self._events_publisher.send_string(event.model_dump_json()) # type: ignore[attr-defined] + self._events_enabled = events_enabled + self._event_publisher = ( + EventPublisher( + worker_id=worker_id, + meta=self._identity.mrok, + event_publisher_port=event_publisher_port, + metrics_interval=metrics_interval, + ) + if events_enabled + else None + ) def setup_app(self): app = ASGIAppWrapper( self._app if not isinstance(self._app, str) else import_from_string(self._app), - lifespan=self.lifespan if self._events_enabled else None, + lifespan=self._event_publisher.lifespan if self._events_enabled else None, ) if self._events_enabled: - app.add_middleware(CaptureMiddleware, self.publish_response_event) - app.add_middleware(MetricsMiddleware, self._metrics_collector) - + self._event_publisher.setup_middleware(app) return app def run(self): setup_logging(get_settings()) app = self.setup_app() - config = MrokBackendConfig(app, self._identity_file) - server = MrokServer(config) + config = BackendConfig(app, self._identity_file) + server = Server(config) with contextlib.suppress(KeyboardInterrupt, asyncio.CancelledError): server.run() diff --git a/mrok/proxy/config.py b/mrok/proxy/ziticorn.py similarity index 62% rename from mrok/proxy/config.py rename to mrok/proxy/ziticorn.py index e8ec602..fb8f704 100644 --- a/mrok/proxy/config.py +++ b/mrok/proxy/ziticorn.py @@ -6,17 +6,40 @@ from typing import Any import openziti -from uvicorn import config +from uvicorn import config, server +from uvicorn.lifespan.on import LifespanOn +from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol as UvHttpToolsProtocol -from mrok.proxy.protocol import MrokHttpToolsProtocol -from mrok.proxy.types import ASGIApp +from mrok.types.proxy import ASGIApp logger = logging.getLogger("mrok.proxy") -config.LIFESPAN["auto"] = "mrok.proxy.lifespan:MrokLifespan" +config.LIFESPAN["auto"] = "mrok.proxy.ziticorn:Lifespan" -class MrokBackendConfig(config.Config): +class Lifespan(LifespanOn): + def __init__(self, lf_config: config.Config) -> None: + super().__init__(lf_config) + self.logger = logging.getLogger("mrok.proxy") + + +class HttpToolsProtocol(UvHttpToolsProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logger = logging.getLogger("mrok.proxy") + self.access_logger = logging.getLogger("mrok.access") + self.access_log = self.access_logger.hasHandlers() + + +class Server(server.Server): + async def serve(self, sockets: list[socket.socket] | None = None) -> None: + if not sockets: + sockets = [self.config.bind_socket()] + with self.capture_signals(): + await self._serve(sockets) + + +class BackendConfig(config.Config): def __init__( self, app: ASGIApp | Callable[..., Any] | str, @@ -32,7 +55,7 @@ def __init__( super().__init__( app, loop="asyncio", - http=MrokHttpToolsProtocol, + http=HttpToolsProtocol, backlog=backlog, ) diff --git a/mrok/types/__init__.py b/mrok/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mrok/proxy/types.py b/mrok/types/proxy.py similarity index 93% rename from mrok/proxy/types.py rename to mrok/types/proxy.py index a8885ce..c8f94e0 100644 --- a/mrok/proxy/types.py +++ b/mrok/types/proxy.py @@ -4,7 +4,7 @@ from contextlib import AbstractAsyncContextManager from typing import Any, Never -from mrok.proxy.datastructures import HTTPResponse +from mrok.proxy.models import HTTPResponse Scope = MutableMapping[str, Any] Message = MutableMapping[str, Any] diff --git a/mrok/types/ziti.py b/mrok/types/ziti.py new file mode 100644 index 0000000..272c70f --- /dev/null +++ b/mrok/types/ziti.py @@ -0,0 +1 @@ +Tags = dict[str, str | bool | None] diff --git a/mrok/ziti/api.py b/mrok/ziti/api.py index 0c21ea9..05f1f2f 100644 --- a/mrok/ziti/api.py +++ b/mrok/ziti/api.py @@ -11,12 +11,11 @@ import httpx from mrok.conf import Settings +from mrok.types.ziti import Tags from mrok.ziti.constants import MROK_VERSION_TAG, MROK_VERSION_TAG_NAME logger = logging.getLogger(__name__) -TagsType = dict[str, str | bool | None] - class ZitiAPIError(Exception): pass @@ -70,7 +69,7 @@ def httpx_client(self) -> httpx.AsyncClient: ), ) - async def create(self, endpoint: str, payload: dict[str, Any], tags: TagsType | None) -> str: + async def create(self, endpoint: str, payload: dict[str, Any], tags: Tags | None) -> str: payload["tags"] = self._merge_tags(tags) response: httpx.Response = await self.httpx_client.post( endpoint, @@ -156,8 +155,8 @@ async def __aexit__( ) -> None: return await self.httpx_client.__aexit__(exc_type, exc_val, exc_tb) - def _merge_tags(self, tags: TagsType | None) -> TagsType: - prepared_tags: TagsType = tags or {} + def _merge_tags(self, tags: Tags | None) -> Tags: + prepared_tags: Tags = tags or {} prepared_tags.update(MROK_VERSION_TAG) return prepared_tags @@ -281,9 +280,7 @@ def identities( async def search_config(self, id_or_name) -> dict[str, Any] | None: return await self.search_by_id_or_name("/configs", id_or_name) - async def create_config( - self, name: str, config_type_id: str, tags: TagsType | None = None - ) -> str: + async def create_config(self, name: str, config_type_id: str, tags: Tags | None = None) -> str: return await self.create( "/configs", { @@ -302,7 +299,7 @@ async def create_config( async def delete_config(self, config_id: str) -> None: return await self.delete("/configs", config_id) - async def create_config_type(self, name: str, tags: TagsType | None = None) -> str: + async def create_config_type(self, name: str, tags: Tags | None = None) -> str: return await self.create( "/config-types", { @@ -316,7 +313,7 @@ async def create_service( self, name: str, config_id: str, - tags: TagsType | None = None, + tags: Tags | None = None, ) -> str: return await self.create( "/services", @@ -332,7 +329,7 @@ async def create_service_router_policy( self, name: str, service_id: str, - tags: TagsType | None = None, + tags: Tags | None = None, ) -> str: return await self.create( "/service-edge-router-policies", @@ -351,7 +348,7 @@ async def create_router_policy( self, name: str, identity_id: str, - tags: TagsType | None = None, + tags: Tags | None = None, ) -> str: return await self.create( "/edge-router-policies", @@ -385,10 +382,10 @@ async def get_service(self, service_id: str) -> dict[str, Any]: async def delete_service(self, service_id: str) -> None: return await self.delete("/services", service_id) - async def create_user_identity(self, name: str, tags: TagsType | None = None) -> str: + async def create_user_identity(self, name: str, tags: Tags | None = None) -> str: return await self._create_identity(name, "User", tags=tags) - async def create_device_identity(self, name: str, tags: TagsType | None = None) -> str: + async def create_device_identity(self, name: str, tags: Tags | None = None) -> str: return await self._create_identity(name, "Device", tags=tags) async def search_identity(self, id_or_name: str) -> dict[str, Any] | None: @@ -412,12 +409,12 @@ async def fetch_ca_certificates(self) -> str: return response.text async def create_dial_service_policy( - self, name: str, service_id: str, identity_id: str, tags: TagsType | None = None + self, name: str, service_id: str, identity_id: str, tags: Tags | None = None ) -> str: return await self._create_service_policy("Dial", name, service_id, identity_id, tags) async def create_bind_service_policy( - self, name: str, service_id: str, identity_id: str, tags: TagsType | None = None + self, name: str, service_id: str, identity_id: str, tags: Tags | None = None ) -> str: return await self._create_service_policy("Bind", name, service_id, identity_id, tags) @@ -433,7 +430,7 @@ async def _create_service_policy( name: str, service_id: str, identity_id: str, - tags: TagsType | None = None, + tags: Tags | None = None, ) -> str: return await self.create( "/service-policies", @@ -451,7 +448,7 @@ async def _create_identity( self, name: str, type: Literal["User", "Device", "Default"], - tags: TagsType | None = None, + tags: Tags | None = None, ) -> str: return await self.create( "/identities", diff --git a/mrok/ziti/bootstrap.py b/mrok/ziti/bootstrap.py index 218a87d..3630be9 100644 --- a/mrok/ziti/bootstrap.py +++ b/mrok/ziti/bootstrap.py @@ -1,7 +1,8 @@ import logging from typing import Any -from mrok.ziti.api import TagsType, ZitiClientAPI, ZitiManagementAPI +from mrok.types.ziti import Tags +from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI from mrok.ziti.identities import enroll_proxy_identity logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ async def bootstrap_identity( identity_name: str, mode: str, forced: bool, - tags: TagsType | None, + tags: Tags | None, ) -> tuple[str, dict[str, Any] | None]: logger.info(f"Bootstrapping '{identity_name}' identity...") diff --git a/mrok/ziti/identities.py b/mrok/ziti/identities.py index 97ad942..d9763cf 100644 --- a/mrok/ziti/identities.py +++ b/mrok/ziti/identities.py @@ -5,8 +5,9 @@ import jwt from mrok.conf import Settings +from mrok.types.ziti import Tags from mrok.ziti import pki -from mrok.ziti.api import TagsType, ZitiClientAPI, ZitiManagementAPI +from mrok.ziti.api import ZitiClientAPI, ZitiManagementAPI from mrok.ziti.constants import ( MROK_IDENTITY_TYPE_TAG_NAME, MROK_IDENTITY_TYPE_TAG_VALUE_INSTANCE, @@ -29,7 +30,7 @@ async def register_identity( client_api: ZitiClientAPI, service_external_id: str, identity_external_id: str, - tags: TagsType | None = None, + tags: Tags | None = None, ): service_name = service_external_id.lower() identity_tags = copy.copy(tags or {}) @@ -39,7 +40,7 @@ async def register_identity( if not service: raise ServiceNotFoundError(f"A service with name `{service_external_id}` does not exists.") - identity_name = f"{identity_external_id.lower()}.{service_name}" + identity_name = identity_external_id.lower() service_policy_name = f"{identity_name}:bind" self_service_policy_name = f"self.{service_policy_name}" @@ -129,7 +130,7 @@ async def enroll_proxy_identity( mgmt_api: ZitiManagementAPI, client_api: ZitiClientAPI, identity_name: str, - tags: TagsType | None = None, + tags: Tags | None = None, ): identity = await mgmt_api.search_identity(identity_name) if identity: diff --git a/mrok/ziti/services.py b/mrok/ziti/services.py index e2ca1d3..a193348 100644 --- a/mrok/ziti/services.py +++ b/mrok/ziti/services.py @@ -2,7 +2,8 @@ from typing import Any from mrok.conf import Settings -from mrok.ziti.api import TagsType, ZitiManagementAPI +from mrok.types.ziti import Tags +from mrok.ziti.api import ZitiManagementAPI from mrok.ziti.errors import ( ConfigTypeNotFoundError, ProxyIdentityNotFoundError, @@ -14,7 +15,7 @@ async def register_service( - settings: Settings, mgmt_api: ZitiManagementAPI, external_id: str, tags: TagsType | None + settings: Settings, mgmt_api: ZitiManagementAPI, external_id: str, tags: Tags | None ) -> dict[str, Any]: service_name = external_id.lower() registered = False diff --git a/tests/agent/sidecar/test_main.py b/tests/agent/sidecar/test_main.py index e1473fd..467cfca 100644 --- a/tests/agent/sidecar/test_main.py +++ b/tests/agent/sidecar/test_main.py @@ -49,6 +49,7 @@ def test_run(mocker: MockerFixture): mocked_agent_ctor.assert_called_once_with( "ziti-identity.json", "target-addr", + events_enabled=True, workers=10, max_connections=15, max_keepalive_connections=3, diff --git a/tests/cli/agent/test_run.py b/tests/cli/agent/test_run.py index d192942..55ca98e 100644 --- a/tests/cli/agent/test_run.py +++ b/tests/cli/agent/test_run.py @@ -60,6 +60,7 @@ def test_run_sidecar( mocked_sidecar.assert_called_once_with( "ins-1234-5678-0001.json", expected_target_addr, + events_enabled=True, workers=2, max_connections=312, max_keepalive_connections=11, diff --git a/tests/conftest.py b/tests/conftest.py index a215c50..3e9fca4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from pytest_httpx import HTTPXMock from mrok.conf import Settings, get_settings -from mrok.proxy.types import ASGIReceive, ASGISend, Message +from mrok.types.proxy import ASGIReceive, ASGISend, Message from tests.types import ReceiveFactory, SendFactory, SettingsFactory diff --git a/tests/proxy/test_app.py b/tests/proxy/test_app.py index b6d3ac4..425dfc2 100644 --- a/tests/proxy/test_app.py +++ b/tests/proxy/test_app.py @@ -9,7 +9,7 @@ from mrok.proxy.app import HOP_BY_HOP_HEADERS, ProxyAppBase from mrok.proxy.exceptions import ProxyError -from mrok.proxy.types import ASGIReceive, ASGISend, Message +from mrok.types.proxy import ASGIReceive, ASGISend, Message from tests.types import ReceiveFactory, SendFactory diff --git a/tests/proxy/test_asgi.py b/tests/proxy/test_asgi.py index 27bb7a5..dab6a6c 100644 --- a/tests/proxy/test_asgi.py +++ b/tests/proxy/test_asgi.py @@ -4,7 +4,7 @@ from pytest_mock import MockerFixture from mrok.proxy.asgi import ASGIAppWrapper -from mrok.proxy.types import Message +from mrok.types.proxy import Message from tests.types import ReceiveFactory, SendFactory diff --git a/tests/proxy/test_config.py b/tests/proxy/test_config.py index c4366d6..bd608d8 100644 --- a/tests/proxy/test_config.py +++ b/tests/proxy/test_config.py @@ -1,19 +1,17 @@ import pytest from pytest_mock import MockerFixture -from mrok.proxy.config import MrokBackendConfig -from mrok.proxy.lifespan import MrokLifespan -from mrok.proxy.protocol import MrokHttpToolsProtocol +from mrok.proxy.ziticorn import BackendConfig, HttpToolsProtocol, Lifespan def test_backend_config_init(ziti_identity_json: dict, ziti_identity_file: str): async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig(fake_asgi_app, ziti_identity_file) + config = BackendConfig(fake_asgi_app, ziti_identity_file) assert config.app == fake_asgi_app assert config.loop == "asyncio" - assert config.http == MrokHttpToolsProtocol + assert config.http == HttpToolsProtocol assert config.lifespan == "auto" assert config.backlog == 2048 assert config.service_name == ziti_identity_json["mrok"]["extension"] @@ -26,12 +24,12 @@ def test_backend_config_load(ziti_identity_file: str): async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig(fake_asgi_app, ziti_identity_file) + config = BackendConfig(fake_asgi_app, ziti_identity_file) config.load() assert config.app == fake_asgi_app assert config.loop == "asyncio" - assert config.http_protocol_class == MrokHttpToolsProtocol - assert config.lifespan_class == MrokLifespan + assert config.http_protocol_class == HttpToolsProtocol + assert config.lifespan_class == Lifespan def test_backend_config_bind_socket( @@ -41,14 +39,14 @@ def test_backend_config_bind_socket( mocked_socket = mocker.MagicMock() mocked_ctx.bind.return_value = mocked_socket mocked_openziti_load = mocker.patch( - "mrok.proxy.config.openziti.load", + "mrok.proxy.ziticorn.openziti.load", return_value=(mocked_ctx, 0), ) async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig( + config = BackendConfig( fake_asgi_app, ziti_identity_file, backlog=4096, ziti_load_timeout_ms=1234 ) assert config.bind_socket() == mocked_socket @@ -64,14 +62,14 @@ def test_backend_config_bind_socket_ziti_error(mocker: MockerFixture, ziti_ident mocked_socket = mocker.MagicMock() mocked_ctx.bind.return_value = mocked_socket mocker.patch( - "mrok.proxy.config.openziti.load", + "mrok.proxy.ziticorn.openziti.load", return_value=(None, 1), ) async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig(fake_asgi_app, ziti_identity_file, backlog=4096) + config = BackendConfig(fake_asgi_app, ziti_identity_file, backlog=4096) with pytest.raises(RuntimeError) as cv: config.bind_socket() diff --git a/tests/proxy/test_event_publisher.py b/tests/proxy/test_event_publisher.py new file mode 100644 index 0000000..4733467 --- /dev/null +++ b/tests/proxy/test_event_publisher.py @@ -0,0 +1,190 @@ +import asyncio +import contextlib + +import pytest +import zmq +from pytest_mock import MockerFixture + +from mrok.proxy.event_publisher import EventPublisher +from mrok.proxy.models import ( + DataTransferMetrics, + Event, + HTTPHeaders, + HTTPRequest, + HTTPResponse, + Identity, + ProcessMetrics, + RequestsMetrics, + ResponseTimeMetrics, + Status, + WorkerMetrics, +) + + +async def test_publish_metrics_event( + mocker: MockerFixture, + ziti_identity_file: str, +): + identity = Identity.load_from_file(ziti_identity_file) + event_publisher = EventPublisher( + worker_id="my-worker-id", + meta=identity.mrok, + ) + + metrics_snapshot = WorkerMetrics( + worker_id="my-worker-id", + data_transfer=DataTransferMetrics( + bytes_in=1000, + bytes_out=2000, + ), + requests=RequestsMetrics(rps=123, total=1000, successful=10, failed=30), + response_time=ResponseTimeMetrics( + avg=10, + min=1, + max=30, + p50=11, + p90=22, + p99=11, + ), + process=ProcessMetrics(cpu=12, mem=22), + ) + + event_publisher._publisher = mocker.AsyncMock() + event_publisher._metrics_collector = mocker.AsyncMock() + event_publisher._metrics_collector.snapshot.return_value = metrics_snapshot # type: ignore + + task = asyncio.create_task(event_publisher.publish_metrics_event()) + await asyncio.sleep(0.1) + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + metrics_events = Event( + type="status", + data=Status( + meta=identity.mrok, + metrics=metrics_snapshot, + ), + ) + event_publisher._publisher.send_string.assert_called_once_with( # type: ignore + metrics_events.model_dump_json() + ) + + +@pytest.mark.asyncio +async def test_publish_response_event( + mocker: MockerFixture, + ziti_identity_file: str, +): + identity = Identity.load_from_file(ziti_identity_file) + event_publisher = EventPublisher( + worker_id="my-worker-id", + meta=identity.mrok, + ) + + resp = HTTPResponse( + type="response", + headers=HTTPHeaders.from_asgi([(b"content-type", b"text/plain")]), + request=HTTPRequest( + method="GET", + url="url", + query_string=b"", + headers=HTTPHeaders.from_asgi([(b"content-type", b"application/json")]), + start_time=0, + ), + status=200, + duration=20.5, + ) + + event_publisher._publisher = mocker.AsyncMock() + + await event_publisher.publish_response_event(resp) + + resp_event = Event(type="response", data=resp) + event_publisher._publisher.send_string.assert_awaited_once_with( # type: ignore + resp_event.model_dump_json() + ) + + +@pytest.mark.asyncio +async def test_lifespan( + mocker: MockerFixture, + ziti_identity_file: str, +): + m_on_startup = mocker.patch.object(EventPublisher, "on_startup") + m_on_shutdown = mocker.patch.object(EventPublisher, "on_shutdown") + + identity = Identity.load_from_file(ziti_identity_file) + event_publisher = EventPublisher( + worker_id="my-worker-id", + meta=identity.mrok, + ) + + m_app = mocker.AsyncMock() + async with event_publisher.lifespan(m_app): + pass + + m_on_startup.assert_awaited_once() + m_on_shutdown.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_on_startup( + mocker: MockerFixture, + ziti_identity_file: str, +): + m_publisher = mocker.MagicMock() + m_zmq_ctx = mocker.MagicMock() + m_zmq_ctx.socket.return_value = m_publisher + m_zmq_ctx_ctor = mocker.MagicMock() + m_zmq_ctx_ctor.return_value = m_zmq_ctx + mocker.patch("mrok.proxy.event_publisher.zmq.asyncio.Context", m_zmq_ctx_ctor) + + m_metrics = mocker.MagicMock() + m_metricscollector_ctor = mocker.patch( + "mrok.proxy.event_publisher.MetricsCollector", return_value=m_metrics + ) + m_publish_metrics_event = mocker.patch.object(EventPublisher, "publish_metrics_event") + + identity = Identity.load_from_file(ziti_identity_file) + event_publisher = EventPublisher( + worker_id="my-worker-id", + meta=identity.mrok, + event_publisher_port=8282, + ) + + await event_publisher.on_startup() + m_metricscollector_ctor.assert_called_once_with("my-worker-id") + assert event_publisher._metrics_collector == m_metrics + assert event_publisher._zmq_ctx == m_zmq_ctx + m_zmq_ctx.socket.assert_called_once_with(zmq.PUB) + m_publisher.connect.assert_called_once_with("tcp://localhost:8282") + await asyncio.sleep(0.001) + m_publish_metrics_event.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_on_shutdown( + mocker: MockerFixture, + ziti_identity_file: str, +): + identity = Identity.load_from_file(ziti_identity_file) + event_publisher = EventPublisher( + worker_id="my-worker-id", + meta=identity.mrok, + ) + + async def my_coro(): + while True: + await asyncio.sleep(5) + + task = asyncio.create_task(my_coro()) + + event_publisher._publish_task = task # type: ignore + event_publisher._publisher = mocker.MagicMock() + event_publisher._zmq_ctx = mocker.MagicMock() + + await event_publisher.on_shutdown() + assert task.cancelled() + event_publisher._publisher.close.assert_called_once() # type: ignore + event_publisher._zmq_ctx.term.assert_called_once() # type: ignore diff --git a/tests/proxy/test_lifespan.py b/tests/proxy/test_lifespan.py index 57c6574..a0618bb 100644 --- a/tests/proxy/test_lifespan.py +++ b/tests/proxy/test_lifespan.py @@ -1,11 +1,11 @@ from pytest_mock import MockerFixture -from mrok.proxy.lifespan import MrokLifespan +from mrok.proxy.ziticorn import Lifespan def test_lifespan(mocker: MockerFixture): mocked_config = mocker.MagicMock(loaded=True) - lifespan = MrokLifespan(mocked_config) + lifespan = Lifespan(mocked_config) assert lifespan.config == mocked_config assert lifespan.logger.name == "mrok.proxy" diff --git a/tests/proxy/test_master.py b/tests/proxy/test_master.py index bf5f8a7..781c3f7 100644 --- a/tests/proxy/test_master.py +++ b/tests/proxy/test_master.py @@ -68,7 +68,7 @@ def test_start_uvicorn_worker_hook( m_app, "my-id-file.json", events_enabled=False, - events_publisher_port=2233, + event_publisher_port=2233, metrics_interval=24.0, ) m_worker.run.assert_called_once() diff --git a/tests/proxy/test_metrics.py b/tests/proxy/test_metrics.py index 8f3555f..dd4de5a 100644 --- a/tests/proxy/test_metrics.py +++ b/tests/proxy/test_metrics.py @@ -1,7 +1,8 @@ import pytest from pytest_mock import MockerFixture -from mrok.proxy.metrics import ProcessMetrics, WorkerMetricsCollector, get_process_metrics +from mrok.proxy.metrics import MetricsCollector, get_process_metrics +from mrok.proxy.models import ProcessMetrics @pytest.mark.asyncio @@ -28,7 +29,7 @@ async def test_worker_metrics_collector( mocker.patch( "mrok.proxy.metrics.get_process_metrics", return_value=ProcessMetrics(cpu=7.3, mem=44.1) ) - collector = WorkerMetricsCollector("my-worker-id") + collector = MetricsCollector("my-worker-id") begin = await collector.on_request_start({}) await collector.on_request_body(23) await collector.on_request_body(32) diff --git a/tests/proxy/test_middlewares.py b/tests/proxy/test_middleware.py similarity index 95% rename from tests/proxy/test_middlewares.py rename to tests/proxy/test_middleware.py index e465962..9aa703b 100644 --- a/tests/proxy/test_middlewares.py +++ b/tests/proxy/test_middleware.py @@ -3,9 +3,9 @@ import pytest from pytest_mock import MockerFixture -from mrok.proxy.datastructures import HTTPResponse -from mrok.proxy.middlewares import CaptureMiddleware, MetricsMiddleware -from mrok.proxy.types import Message +from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware +from mrok.proxy.models import HTTPResponse +from mrok.types.proxy import Message from tests.types import ReceiveFactory, SendFactory @@ -69,7 +69,7 @@ async def test_capture( receive_factory: ReceiveFactory, send_factory: SendFactory, ): - mocker.patch("mrok.proxy.middlewares.time.time", side_effect=[7, 25]) + mocker.patch("mrok.proxy.middleware.time.time", side_effect=[7, 25]) class MockApp: async def __call__(self, scope, receive, send): diff --git a/tests/proxy/test_protocol.py b/tests/proxy/test_protocol.py index 5d7c51d..6fba587 100644 --- a/tests/proxy/test_protocol.py +++ b/tests/proxy/test_protocol.py @@ -1,11 +1,12 @@ from pytest_mock import MockerFixture +from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol as UvHttpToolsProtocol -from mrok.proxy.protocol import HttpToolsProtocol, MrokHttpToolsProtocol +from mrok.proxy.ziticorn import HttpToolsProtocol def test_protocol(mocker: MockerFixture): - mocked_super_init = mocker.patch.object(HttpToolsProtocol, "__init__") - proto = MrokHttpToolsProtocol("10", test=True) + mocked_super_init = mocker.patch.object(UvHttpToolsProtocol, "__init__") + proto = HttpToolsProtocol("10", test=True) mocked_super_init.assert_called_once_with("10", test=True) assert proto.access_logger.name == "mrok.access" assert proto.logger.name == "mrok.proxy" diff --git a/tests/proxy/test_server.py b/tests/proxy/test_server.py index f2a1d45..2f476b3 100644 --- a/tests/proxy/test_server.py +++ b/tests/proxy/test_server.py @@ -1,25 +1,22 @@ import pytest from pytest_mock import MockerFixture -from mrok.proxy.config import MrokBackendConfig -from mrok.proxy.server import MrokServer +from mrok.proxy.ziticorn import BackendConfig, Server @pytest.mark.asyncio async def test_serve(mocker: MockerFixture): mocked_socket = mocker.MagicMock() - mocker.patch.object(MrokBackendConfig, "bind_socket", return_value=mocked_socket) - mocker.patch.object( - MrokBackendConfig, "get_identity_info", return_value=("ext", "ins.ext", "ins") - ) - mocked_inner_serve = mocker.patch.object(MrokServer, "_serve") + mocker.patch.object(BackendConfig, "bind_socket", return_value=mocked_socket) + mocker.patch.object(BackendConfig, "get_identity_info", return_value=("ext", "ins.ext", "ins")) + mocked_inner_serve = mocker.patch.object(Server, "_serve") async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig(fake_asgi_app, "ziti-identity.json") + config = BackendConfig(fake_asgi_app, "ziti-identity.json") - server = MrokServer(config) + server = Server(config) await server.serve() mocked_inner_serve.assert_awaited_once_with([mocked_socket]) @@ -27,18 +24,16 @@ async def fake_asgi_app(scope, receive, send): @pytest.mark.asyncio async def test_serve_with_socket(mocker: MockerFixture): mocked_socket = mocker.MagicMock() - mocked_bind = mocker.patch.object(MrokBackendConfig, "bind_socket") - mocker.patch.object( - MrokBackendConfig, "get_identity_info", return_value=("ext", "ins.ext", "ins") - ) - mocked_inner_serve = mocker.patch.object(MrokServer, "_serve") + mocked_bind = mocker.patch.object(BackendConfig, "bind_socket") + mocker.patch.object(BackendConfig, "get_identity_info", return_value=("ext", "ins.ext", "ins")) + mocked_inner_serve = mocker.patch.object(Server, "_serve") async def fake_asgi_app(scope, receive, send): pass - config = MrokBackendConfig(fake_asgi_app, "ziti-identity.json") + config = BackendConfig(fake_asgi_app, "ziti-identity.json") - server = MrokServer(config) + server = Server(config) await server.serve([mocked_socket]) mocked_inner_serve.assert_awaited_once_with([mocked_socket]) mocked_bind.assert_not_called() diff --git a/tests/proxy/test_streams.py b/tests/proxy/test_stream.py similarity index 96% rename from tests/proxy/test_streams.py rename to tests/proxy/test_stream.py index e0489bc..97f7b84 100644 --- a/tests/proxy/test_streams.py +++ b/tests/proxy/test_stream.py @@ -3,7 +3,7 @@ import pytest from pytest_mock import MockerFixture -from mrok.proxy.streams import AIONetworkStream, ASGIRequestBodyStream +from mrok.proxy.stream import AIONetworkStream, ASGIRequestBodyStream @pytest.mark.asyncio @@ -75,7 +75,7 @@ def test_aio_network_stream_extra_info_is_readable( m_writer.transport = mocker.MagicMock() m_writer.transport.get_extra_info.return_value = m_sock - m_is_readable = mocker.patch("mrok.proxy.streams.is_readable", return_value=readable) + m_is_readable = mocker.patch("mrok.proxy.stream.is_readable", return_value=readable) aions = AIONetworkStream(mocker.MagicMock(), m_writer) diff --git a/tests/proxy/test_worker.py b/tests/proxy/test_worker.py index 7e1a64e..9a939e4 100644 --- a/tests/proxy/test_worker.py +++ b/tests/proxy/test_worker.py @@ -1,24 +1,7 @@ -import asyncio -import contextlib - -import pytest -import zmq from pytest_mock import MockerFixture from mrok.proxy.asgi import ASGIAppWrapper -from mrok.proxy.datastructures import ( - DataTransferMetrics, - Event, - HTTPHeaders, - HTTPRequest, - HTTPResponse, - ProcessMetrics, - RequestsMetrics, - ResponseTimeMetrics, - Status, - WorkerMetrics, -) -from mrok.proxy.middlewares import CaptureMiddleware, MetricsMiddleware +from mrok.proxy.middleware import CaptureMiddleware, MetricsMiddleware from mrok.proxy.worker import Worker from tests.types import SettingsFactory @@ -35,11 +18,12 @@ def test_setup_app( ) app = worker.setup_app() assert isinstance(app, ASGIAppWrapper) - assert app.lifespan == worker.lifespan + assert worker._event_publisher is not None + assert app.lifespan == worker._event_publisher.lifespan assert app.middlware[0].cls == MetricsMiddleware - assert app.middlware[0].args[0] == worker._metrics_collector + assert app.middlware[0].args[0] == worker._event_publisher._metrics_collector assert app.middlware[1].cls == CaptureMiddleware - assert app.middlware[1].args[0] == worker.publish_response_event + assert app.middlware[1].args[0] == worker._event_publisher.publish_response_event def test_setup_app_events_disabled( @@ -70,12 +54,10 @@ def test_run( m_setup_logging = mocker.patch("mrok.proxy.worker.setup_logging") m_mrokconfig = mocker.MagicMock() - m_mrokconfig_ctor = mocker.patch( - "mrok.proxy.worker.MrokBackendConfig", return_value=m_mrokconfig - ) + m_mrokconfig_ctor = mocker.patch("mrok.proxy.worker.BackendConfig", return_value=m_mrokconfig) m_server = mocker.MagicMock() - m_server_ctor = mocker.patch("mrok.proxy.worker.MrokServer", return_value=m_server) + m_server_ctor = mocker.patch("mrok.proxy.worker.Server", return_value=m_server) m_app = mocker.MagicMock() mocker.patch.object(Worker, "setup_app", return_value=m_app) @@ -90,163 +72,3 @@ def test_run( m_mrokconfig_ctor.assert_called_once_with(m_app, ziti_identity_file) m_server_ctor.assert_called_once_with(m_mrokconfig) m_server.run.assert_called_once() - - -@pytest.mark.asyncio -async def test_lifespan( - mocker: MockerFixture, - ziti_identity_file: str, -): - m_app = mocker.AsyncMock() - m_on_startup = mocker.patch.object(Worker, "on_startup") - m_on_shutdown = mocker.patch.object(Worker, "on_shutdown") - - worker = Worker( - "my-worker-id", - m_app, - ziti_identity_file, - ) - - async with worker.lifespan(m_app): - pass - - m_on_startup.assert_awaited_once() - m_on_shutdown.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_on_startup( - mocker: MockerFixture, - ziti_identity_file: str, -): - m_publisher = mocker.MagicMock() - m_zmq_ctx = mocker.MagicMock() - m_zmq_ctx.socket.return_value = m_publisher - m_zmq_ctx_ctor = mocker.MagicMock() - m_zmq_ctx_ctor.return_value = m_zmq_ctx - mocker.patch("mrok.proxy.worker.zmq.asyncio.Context", m_zmq_ctx_ctor) - - m_metrics = mocker.MagicMock() - m_metricscollector_ctor = mocker.patch( - "mrok.proxy.worker.WorkerMetricsCollector", return_value=m_metrics - ) - m_publish_metrics_event = mocker.patch.object(Worker, "publish_metrics_event") - - m_app = mocker.AsyncMock() - worker = Worker("my-worker-id", m_app, ziti_identity_file, events_publisher_port=8282) - await worker.on_startup() - m_metricscollector_ctor.assert_called_once_with("my-worker-id") - assert worker._metrics_collector == m_metrics - assert worker._zmq_ctx == m_zmq_ctx - m_zmq_ctx.socket.assert_called_once_with(zmq.PUB) - assert worker._events_publisher == m_publisher - m_publisher.connect.assert_called_once_with("tcp://localhost:8282") - await asyncio.sleep(0.001) - m_publish_metrics_event.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_on_shutdown( - mocker: MockerFixture, - ziti_identity_file: str, -): - m_app = mocker.AsyncMock() - worker = Worker("my-worker-id", m_app, ziti_identity_file) - - async def my_coro(): - while True: - await asyncio.sleep(5) - - task = asyncio.create_task(my_coro()) - - worker._events_publish_task = task # type: ignore - worker._zmq_ctx = mocker.MagicMock() - worker._events_publisher = mocker.MagicMock() - - await worker.on_shutdown() - assert task.cancelled() - worker._events_publisher.close.assert_called_once() # type: ignore - worker._zmq_ctx.term.assert_called_once() # type: ignore - - -async def test_publish_metrics_event( - mocker: MockerFixture, - ziti_identity_file: str, -): - m_app = mocker.AsyncMock() - worker = Worker( - "my-worker-id", - m_app, - ziti_identity_file, - ) - - metrics_snapshot = WorkerMetrics( - worker_id="my-worker-id", - data_transfer=DataTransferMetrics( - bytes_in=1000, - bytes_out=2000, - ), - requests=RequestsMetrics(rps=123, total=1000, successful=10, failed=30), - response_time=ResponseTimeMetrics( - avg=10, - min=1, - max=30, - p50=11, - p90=22, - p99=11, - ), - process=ProcessMetrics(cpu=12, mem=22), - ) - - worker._metrics_collector = mocker.AsyncMock() - worker._metrics_collector.snapshot.return_value = metrics_snapshot # type: ignore - worker._events_publisher = mocker.AsyncMock() - - task = asyncio.create_task(worker.publish_metrics_event()) - await asyncio.sleep(0.1) - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await task - - metrics_events = Event( - type="status", - data=Status( - meta=worker._identity.mrok, - metrics=metrics_snapshot, - ), - ) - worker._events_publisher.send_string.assert_called_once_with(metrics_events.model_dump_json()) # type: ignore - - -@pytest.mark.asyncio -async def test_publish_response_event( - mocker: MockerFixture, - ziti_identity_file: str, -): - m_app = mocker.AsyncMock() - worker = Worker( - "my-worker-id", - m_app, - ziti_identity_file, - ) - resp = HTTPResponse( - type="response", - headers=HTTPHeaders.from_asgi([(b"content-type", b"text/plain")]), - request=HTTPRequest( - method="GET", - url="url", - query_string=b"", - headers=HTTPHeaders.from_asgi([(b"content-type", b"application/json")]), - start_time=0, - ), - status=200, - duration=20.5, - ) - - worker._events_publisher = mocker.AsyncMock() - - await worker.publish_response_event(resp) - - resp_event = Event(type="response", data=resp) - - worker._events_publisher.send_string.assert_awaited_once_with(resp_event.model_dump_json()) # type: ignore diff --git a/tests/types.py b/tests/types.py index a8e9315..db206fe 100644 --- a/tests/types.py +++ b/tests/types.py @@ -6,7 +6,7 @@ from textual.pilot import Pilot from mrok.conf import Settings -from mrok.proxy.types import ASGIReceive, ASGISend, Message +from mrok.types.proxy import ASGIReceive, ASGISend, Message class SnapCompare(Protocol): diff --git a/tests/ziti/test_api.py b/tests/ziti/test_api.py index 797e223..fd32389 100644 --- a/tests/ziti/test_api.py +++ b/tests/ziti/test_api.py @@ -5,8 +5,8 @@ from pytest_httpx import HTTPXMock from pytest_mock import MockerFixture +from mrok.types.ziti import Tags from mrok.ziti.api import ( - TagsType, ZitiAuthError, ZitiBadRequestError, ZitiClientAPI, @@ -22,7 +22,7 @@ async def test_create( settings_factory: SettingsFactory, httpx_mock: HTTPXMock, - tags: TagsType | None, + tags: Tags | None, ): settings = settings_factory() expected_body = { diff --git a/tests/ziti/test_identities.py b/tests/ziti/test_identities.py index 5258e90..e04bcef 100644 --- a/tests/ziti/test_identities.py +++ b/tests/ziti/test_identities.py @@ -69,13 +69,11 @@ async def test_register_instance(mocker: MockerFixture, settings_factory: Settin } assert mocked_mgmt_api.search_service.mock_calls[0].args[0] == "ext-1234-5678" - assert ( - mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001.ext-1234-5678" - ) + assert mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001" - mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001.ext-1234-5678") + mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001") mocked_mgmt_api.create_user_identity.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", tags={ MROK_SERVICE_TAG_NAME: "ext-1234-5678", "account": "ACC-1234", @@ -105,29 +103,29 @@ async def test_register_instance(mocker: MockerFixture, settings_factory: Settin "account": "ACC-1234", } assert identity_json["mrok"]["domain"] == settings.proxy.domain - assert identity_json["mrok"]["identity"] == "ins-1234-5678-0001.ext-1234-5678" + assert identity_json["mrok"]["identity"] == "ins-1234-5678-0001" assert identity_json["mrok"]["extension"] == "EXT-1234-5678" assert identity_json["mrok"]["instance"] == "INS-1234-5678-0001" mocked_register_service.assert_called_once_with( settings, mocked_mgmt_api, - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", {"account": "ACC-1234"}, ) assert mocked_mgmt_api.create_bind_service_policy.mock_calls[0].args == ( - "ins-1234-5678-0001.ext-1234-5678:bind", + "ins-1234-5678-0001:bind", "svc1", "identity-id", ) assert mocked_mgmt_api.create_bind_service_policy.mock_calls[1].args == ( - "self.ins-1234-5678-0001.ext-1234-5678:bind", + "self.ins-1234-5678-0001:bind", "self-service-id", "identity-id", ) mocked_mgmt_api.create_router_policy.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", "identity-id", ) @@ -197,12 +195,10 @@ async def test_register_instance_identity_exists( ) assert mocked_mgmt_api.search_service.mock_calls[0].args[0] == "ext-1234-5678" - assert ( - mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001.ext-1234-5678" - ) - mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001.ext-1234-5678") + assert mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001" + mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001") mocked_mgmt_api.create_user_identity.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", tags={ MROK_SERVICE_TAG_NAME: "ext-1234-5678", "account": "ACC-1234", @@ -230,31 +226,29 @@ async def test_register_instance_identity_exists( mocked_register_service.assert_not_awaited() assert mocked_mgmt_api.create_bind_service_policy.mock_calls[0].args == ( - "ins-1234-5678-0001.ext-1234-5678:bind", + "ins-1234-5678-0001:bind", "svc1", "identity-id", ) assert mocked_mgmt_api.create_bind_service_policy.mock_calls[1].args == ( - "self.ins-1234-5678-0001.ext-1234-5678:bind", + "self.ins-1234-5678-0001:bind", "self-service-id", "identity-id", ) mocked_mgmt_api.create_router_policy.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", "identity-id", ) assert mocked_mgmt_api.search_service_policy.mock_calls[0].args[0] == ( - "ins-1234-5678-0001.ext-1234-5678:bind" + "ins-1234-5678-0001:bind" ) assert mocked_mgmt_api.search_service_policy.mock_calls[1].args[0] == ( - "self.ins-1234-5678-0001.ext-1234-5678:bind" + "self.ins-1234-5678-0001:bind" ) assert mocked_mgmt_api.delete_service_policy.mock_calls[0].args[0] == "service-policy-id" assert mocked_mgmt_api.delete_service_policy.mock_calls[1].args[0] == "self-service-policy-id" - mocked_mgmt_api.search_router_policy.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678" - ) + mocked_mgmt_api.search_router_policy.assert_awaited_once_with("ins-1234-5678-0001") mocked_mgmt_api.delete_router_policy.assert_awaited_once_with("router-policy-id") mocked_mgmt_api.delete_identity.assert_awaited_once_with("identity-id") @@ -298,12 +292,10 @@ async def test_register_instance_identity_exists_service_router_doesnt( ) assert mocked_mgmt_api.search_service.mock_calls[0].args[0] == "ext-1234-5678" - assert ( - mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001.ext-1234-5678" - ) - mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001.ext-1234-5678") + assert mocked_mgmt_api.search_service.mock_calls[1].args[0] == "ins-1234-5678-0001" + mocked_mgmt_api.search_identity.assert_awaited_once_with("ins-1234-5678-0001") mocked_mgmt_api.create_user_identity.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", tags={ MROK_SERVICE_TAG_NAME: "ext-1234-5678", "account": "ACC-1234", @@ -331,34 +323,32 @@ async def test_register_instance_identity_exists_service_router_doesnt( mocked_register_service.assert_called_once_with( settings, mocked_mgmt_api, - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", {"account": "ACC-1234"}, ) assert mocked_mgmt_api.create_bind_service_policy.mock_calls[0].args == ( - "ins-1234-5678-0001.ext-1234-5678:bind", + "ins-1234-5678-0001:bind", "svc1", "identity-id", ) assert mocked_mgmt_api.create_bind_service_policy.mock_calls[1].args == ( - "self.ins-1234-5678-0001.ext-1234-5678:bind", + "self.ins-1234-5678-0001:bind", "self-service-id", "identity-id", ) mocked_mgmt_api.create_router_policy.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678", + "ins-1234-5678-0001", "identity-id", ) assert mocked_mgmt_api.search_service_policy.mock_calls[0].args[0] == ( - "ins-1234-5678-0001.ext-1234-5678:bind" + "ins-1234-5678-0001:bind" ) assert mocked_mgmt_api.search_service_policy.mock_calls[1].args[0] == ( - "self.ins-1234-5678-0001.ext-1234-5678:bind" + "self.ins-1234-5678-0001:bind" ) mocked_mgmt_api.delete_service_policy.assert_not_awaited() - mocked_mgmt_api.search_router_policy.assert_awaited_once_with( - "ins-1234-5678-0001.ext-1234-5678" - ) + mocked_mgmt_api.search_router_policy.assert_awaited_once_with("ins-1234-5678-0001") mocked_mgmt_api.delete_router_policy.assert_not_awaited() mocked_mgmt_api.delete_identity.assert_awaited_once_with("identity-id") diff --git a/tests/ziti/test_services.py b/tests/ziti/test_services.py index 4939abf..66a743a 100644 --- a/tests/ziti/test_services.py +++ b/tests/ziti/test_services.py @@ -1,7 +1,7 @@ import pytest from pytest_mock import MockerFixture -from mrok.ziti.api import TagsType +from mrok.types.ziti import Tags from mrok.ziti.services import ( ConfigTypeNotFoundError, ProxyIdentityNotFoundError, @@ -27,7 +27,7 @@ async def test_register_extension(mocker: MockerFixture, settings_factory: Setti mocked_api.search_service_policy.return_value = None mocked_api.search_service_router_policy.return_value = None - tags: TagsType = {"tag": "my-tag"} + tags: Tags = {"tag": "my-tag"} await register_service(settings, mocked_api, "EXT-1234", tags) @@ -101,7 +101,7 @@ async def test_register_extension_config_exists( mocked_api.search_service_policy.return_value = None mocked_api.search_service_router_policy.return_value = None - tags: TagsType = {"tag": "my-tag"} + tags: Tags = {"tag": "my-tag"} await register_service(settings, mocked_api, "EXT-1234", tags) @@ -141,7 +141,7 @@ async def test_register_extension_service_exists( mocked_api.search_service_policy.return_value = None mocked_api.search_service_router_policy.return_value = None - tags: TagsType = {"tag": "my-tag"} + tags: Tags = {"tag": "my-tag"} await register_service(settings, mocked_api, "EXT-1234", tags) @@ -182,7 +182,7 @@ async def test_register_extension_dial_policy_exists( mocked_api.search_service_policy.return_value = {"id": "policy_id"} mocked_api.search_service_router_policy.return_value = None - tags: TagsType = {"tag": "my-tag"} + tags: Tags = {"tag": "my-tag"} await register_service(settings, mocked_api, "EXT-1234", tags) @@ -218,7 +218,7 @@ async def test_register_extension_router_policy_exists( mocked_api.search_service_policy.return_value = {"id": "policy_id"} mocked_api.search_service_router_policy.return_value = {"id": "policy_id"} - tags: TagsType = {"tag": "my-tag"} + tags: Tags = {"tag": "my-tag"} with pytest.raises(ServiceAlreadyRegisteredError) as cv: await register_service(settings, mocked_api, "EXT-1234", tags)