From 96db993896c396163e992d87be2fa9cfd17e00ee Mon Sep 17 00:00:00 2001
From: Francesco Faraone
Date: Thu, 11 Dec 2025 11:53:16 +0100
Subject: [PATCH] move to aiocache
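
Replace the hand-rolled TTL cache in ZitiConnectionManager with aiocache's
in-memory Cache. Connection keys collapse from a (service, terminator)
tuple to the plain service-name string, the per-key creation locks and the
purge loop go away, and idle-connection expiry is delegated to aiocache
TTLs plus a periodic cleanup task that closes stream pairs whose cache
entry has expired. On the proxy side, the target is now resolved from
X-Forwarded-Host (falling back to Host) instead of the misspelled
x-forwared-for header.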
---
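Notes (not part of the commit message): a minimal sketch of the new
per-service flow. The identity path and the "my-extension" service name
are made-up placeholders, and the wrapped CachedStreamReader/Writer pair
is assumed to delegate write()/drain()/readline() to the underlying
asyncio streams, as streams.py does today.

    import asyncio

    from mrok.proxy.ziti import ZitiConnectionManager

    async def main() -> None:
        manager = ZitiConnectionManager("identity.json", ttl_seconds=60.0)
        await manager.start()
        try:
            # The first call dials the Ziti service; later calls within the
            # TTL window get the cached reader/writer pair back.
            reader, writer = await manager.get_or_create("my-extension")
            writer.write(b"GET / HTTP/1.1\r\nHost: my-extension\r\n\r\n")
            await writer.drain()
            print(await reader.readline())
        finally:
            await manager.stop()

    asyncio.run(main())
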
 mrok/proxy/app.py     |  40 ++++----
 mrok/proxy/streams.py |   6 +-
 mrok/proxy/types.py   |   5 +-
 mrok/proxy/ziti.py    | 229 ++++++++++++++++++-------------------------
 pyproject.toml        |   1 +
 uv.lock               |  11 +++
 6 files changed, 124 insertions(+), 168 deletions(-)

diff --git a/mrok/proxy/app.py b/mrok/proxy/app.py
index fb22778..1cc624e 100644
--- a/mrok/proxy/app.py
+++ b/mrok/proxy/app.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from pathlib import Path
 
 from mrok.conf import get_settings
@@ -35,37 +34,34 @@ def __init__(
         self._conn_manager = ZitiConnectionManager(
             identity_file,
             ttl_seconds=ziti_connection_ttl_seconds,
-            purge_interval=ziti_conn_cache_purge_interval_seconds,
+            cleanup_interval=ziti_conn_cache_purge_interval_seconds,
         )
 
-    def get_target_from_header(self, name: str, headers: dict[str, str]) -> str:
-        header_value = headers.get(name)
-        if not header_value:
-            raise ProxyError(
-                f"Header {name} not found!",
-            )
-        if ":" in header_value:
-            header_value, _ = header_value.split(":", 1)
-        if not header_value.endswith(self._proxy_wildcard_domain):
-            raise ProxyError(f"Unexpected value for {name} header: `{header_value}`.")
-
-        return header_value[: -len(self._proxy_wildcard_domain)]
+    def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
+        header_value = headers.get(name, "")
+        # Strip an optional port before matching, and only accept the
+        # wildcard domain as a true suffix: slicing len(domain) characters
+        # off the end is wrong when the domain merely appears in the value.
+        if ":" in header_value:
+            header_value, _ = header_value.split(":", 1)
+        if header_value.endswith(self._proxy_wildcard_domain):
+            return header_value[: -len(self._proxy_wildcard_domain)]
+        return None
 
     def get_target_name(self, headers: dict[str, str]) -> str:
-        try:
-            return self.get_target_from_header("x-forwared-for", headers)
-        except ProxyError as pe:
-            logger.warning(pe)
-            return self.get_target_from_header("host", headers)
+        target = self.get_target_from_header(headers, "x-forwarded-host")
+        if not target:
+            target = self.get_target_from_header(headers, "host")
+        if not target:
+            raise ProxyError("Neither X-Forwarded-Host nor Host contains a valid target name")
+        return target
 
     async def startup(self):
         setup_logging(get_settings())
         await self._conn_manager.start()
-        logger.info(f"Proxy app startup completed: {os.getpid()}")
 
     async def shutdown(self):
         await self._conn_manager.stop()
-        logger.info(f"Proxy app shutdown completed: {os.getpid()}")
 
     async def select_backend(
         self,
@@ -74,4 +70,4 @@ async def select_backend(
     ) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
         target_name = self.get_target_name(headers)
 
-        return await self._conn_manager.get(target_name)
+        return await self._conn_manager.get_or_create(target_name)
diff --git a/mrok/proxy/streams.py b/mrok/proxy/streams.py
index 1647449..074e627 100644
--- a/mrok/proxy/streams.py
+++ b/mrok/proxy/streams.py
@@ -1,13 +1,13 @@
 import asyncio
 
-from mrok.proxy.types import ConnectionCache, ConnectionKey
+from mrok.proxy.types import ConnectionCache
 
 
 class CachedStreamReader:
     def __init__(
         self,
         reader: asyncio.StreamReader,
-        key: ConnectionKey,
+        key: str,
         manager: ConnectionCache,
     ):
         self._reader = reader
@@ -77,7 +77,7 @@ class CachedStreamWriter:
     def __init__(
         self,
         writer: asyncio.StreamWriter,
-        key: ConnectionKey,
+        key: str,
         manager: ConnectionCache,
     ):
         self._writer = writer
diff --git a/mrok/proxy/types.py b/mrok/proxy/types.py
index 2a1f22c..9d5d93e 100644
--- a/mrok/proxy/types.py
+++ b/mrok/proxy/types.py
@@ -4,9 +4,8 @@
 from mrok.http.types import StreamReader, StreamWriter
 
-ConnectionKey = tuple[str, str | None]
-CachedStream = tuple[StreamReader, StreamWriter]
+StreamPair = tuple[StreamReader, StreamWriter]
 
 
 class ConnectionCache(Protocol):
-    async def invalidate(self, key: ConnectionKey) -> None: ...
+    async def invalidate(self, key: str) -> None: ...
 
diff --git a/mrok/proxy/ziti.py b/mrok/proxy/ziti.py
index 55b95a0..8720875 100644
--- a/mrok/proxy/ziti.py
+++ b/mrok/proxy/ziti.py
@@ -1,22 +1,13 @@
-"""Ziti-backed connection manager for the proxy.
-
-This manager owns creation of connections via an OpenZiti context, wraps
-streams to observe IO errors, evicts idle entries, and serializes creation
-per-key.
-"""
-
 import asyncio
+import contextlib
 import logging
-import time
 from pathlib import Path
 
-# typing imports intentionally minimized
 import openziti
+from aiocache import Cache
 
-from mrok.http.types import StreamReader, StreamWriter
-from mrok.proxy.dataclasses import CachedStreamEntry
 from mrok.proxy.streams import CachedStreamReader, CachedStreamWriter
-from mrok.proxy.types import CachedStream, ConnectionKey
+from mrok.proxy.types import StreamPair
 
 
 logger = logging.getLogger("mrok.proxy")
@@ -27,147 +18,106 @@ def __init__(
         self,
         identity_file: str | Path,
         ziti_timeout_ms: int = 10000,
         ttl_seconds: float = 60.0,
-        purge_interval: float = 10.0,
+        cleanup_interval: float = 10.0,
     ):
-        self._identity_file = identity_file
-        self._ziti_ctx = None
-        self._ziti_timeout_ms = ziti_timeout_ms
-        self._ttl = float(ttl_seconds)
-        self._purge_interval = float(purge_interval)
-        self._cache: dict[ConnectionKey, CachedStreamEntry] = {}
-        self._lock = asyncio.Lock()
-        self._in_progress: dict[ConnectionKey, asyncio.Lock] = {}
-        self._purge_task: asyncio.Task | None = None
-
-    async def get(self, target: str) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
-        head, _, tail = target.partition(".")
-        terminator = target if head and tail else ""
-        service = tail if tail else head
-        r, w = await self._get_or_create_key((service, terminator))
-        return r, w
-
-    async def invalidate(self, key: ConnectionKey) -> None:
-        async with self._lock:
-            item = self._cache.pop(key, None)
-            if item is None:
-                return
-        await self._close_writer(item.writer)
+        self.identity_file = identity_file
+        self.ziti_timeout_ms = ziti_timeout_ms
+        self.ttl_seconds = ttl_seconds
+        self.cleanup_interval = cleanup_interval
+
+        self.cache = Cache(Cache.MEMORY)
+
+        self._active_pairs: dict[str, StreamPair] = {}
+
+        self._cleanup_task: asyncio.Task | None = None
+        self._ziti_ctx: openziti.context.ZitiContext | None = None
+
+    async def create_stream_pair(self, key: str) -> StreamPair:
+        if not self._ziti_ctx:
+            raise Exception("ZitiConnectionManager is not started")
+        sock = self._ziti_ctx.connect(key)
+        orig_reader, orig_writer = await asyncio.open_connection(sock=sock)
+
+        reader = CachedStreamReader(orig_reader, key, self)
+        writer = CachedStreamWriter(orig_writer, key, self)
+        return (reader, writer)
+
+    async def get_or_create(self, key: str) -> StreamPair:
+        pair = await self.cache.get(key)
+
+        if pair:
+            logger.info(f"return cached connection for {key}")
+            # Re-setting the entry refreshes the TTL (sliding expiry).
+            await self.cache.set(key, pair, ttl=self.ttl_seconds)
+            self._active_pairs[key] = pair
+            return pair
+
+        pair = await self.create_stream_pair(key)
+        await self.cache.set(key, pair, ttl=self.ttl_seconds)
+        self._active_pairs[key] = pair
+        logger.info(f"return new connection for {key}")
+        return pair
+
+    async def invalidate(self, key: str) -> None:
+        logger.info(f"invalidating connection for {key}")
+        pair = await self.cache.get(key)
+        if pair:
+            await self._close_pair(pair)
+
+        await self.cache.delete(key)
+        self._active_pairs.pop(key, None)
 
     async def start(self) -> None:
+        if self._cleanup_task is None:
+            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
         if self._ziti_ctx is None:
-            ctx, err = openziti.load(str(self._identity_file), timeout=self._ziti_timeout_ms)
+            ctx, err = openziti.load(str(self.identity_file), timeout=self.ziti_timeout_ms)
             if err != 0:
                 raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
             self._ziti_ctx = ctx
-        if self._purge_task is None:
-            self._purge_task = asyncio.create_task(self._purge_loop())
-        logger.info("Ziti connection manager started")
 
     async def stop(self) -> None:
-        if self._purge_task is not None:
-            self._purge_task.cancel()
-            try:
-                await self._purge_task
-            except asyncio.CancelledError:
-                logger.debug("Purge task was cancelled")
-            except Exception as e:
-                logger.warning(f"An error occurred stopping the purge task: {e}")
-            self._purge_task = None
-        logger.info("Ziti connection manager stopped")
-
-        async with self._lock:
-            items = list(self._cache.items())
-            self._cache.clear()
-
-        for _, item in items:
-            await self._close_writer(item.writer)
-
-    async def _purge_loop(self) -> None:
+        if self._cleanup_task:
+            self._cleanup_task.cancel()
+            # CancelledError is not an Exception subclass, so it has to be
+            # suppressed explicitly or stop() re-raises the cancellation.
+            with contextlib.suppress(Exception, asyncio.CancelledError):
+                await self._cleanup_task
+            self._cleanup_task = None
+
+        for pair in list(self._active_pairs.values()):
+            await self._close_pair(pair)
+
+        self._active_pairs.clear()
+        await self.cache.clear()
+        openziti.shutdown()
+
+    @staticmethod
+    async def _close_pair(pair: StreamPair) -> None:
+        reader, writer = pair
+        writer.close()
+        with contextlib.suppress(Exception):
+            await writer.wait_closed()
+
+    async def _periodic_cleanup(self) -> None:
         try:
             while True:
-                await asyncio.sleep(self._purge_interval)
-                await self._purge_once()
+                await asyncio.sleep(self.cleanup_interval)
+                await self._cleanup_once()
         except asyncio.CancelledError:
             return
 
-    async def _purge_once(self) -> None:
-        to_close: list[tuple[StreamReader, StreamWriter]] = []
-        async with self._lock:
-            now = time.time()
-            for key, item in list(self._cache.items()):
-                if now - item.last_access > self._ttl:
-                    to_close.append((item.reader, item.writer))
-                    del self._cache[key]
-
-        for _, writer in to_close:
-            writer.close()
-            await self._close_writer(writer)
+    async def _cleanup_once(self) -> None:
+        # Keys whose entry is still alive in aiocache; exists() is part of
+        # the aiocache base API, while the Cache frontend has no keys().
+        keys_in_cache = {key for key in self._active_pairs if await self.cache.exists(key)}
+        # Keys we think are alive
+        known_keys = set(self._active_pairs.keys())
 
-    def _is_writer_closed(self, writer: StreamWriter) -> bool:
-        return writer.transport.is_closing()
+        expired_keys = known_keys - keys_in_cache
 
-    async def _close_writer(self, writer: StreamWriter) -> None:
-        writer.close()
-        try:
-            await writer.wait_closed()
-        except Exception as e:
-            logger.debug(f"Error closing writer: {e}")
-
-    async def _get_or_create_key(self, key: ConnectionKey) -> CachedStream:
-        """Internal: create or return a cached wrapped pair for the concrete key."""
-        await self._purge_once()
-        to_close = None
-        async with self._lock:
-            if key in self._cache:
-                now = time.time()
-                item = self._cache[key]
-                reader, writer = item.reader, item.writer
-                if not self._is_writer_closed(writer) and not reader.at_eof():
-                    self._cache[key] = CachedStreamEntry(reader, writer, now)
-                    return reader, writer
-                to_close = writer
-                del self._cache[key]
-
-            lock = self._in_progress.get(key)
-            if lock is None:
-                lock = asyncio.Lock()
-                self._in_progress[key] = lock
-
-        if to_close:
-            await self._close_writer(to_close)
-
-        async with lock:
-            try:
-                # # double-check cache after acquiring the per-key lock
-                # async with self._lock:
-                #     now = time.time()
-                #     if key in self._cache:
-                #         r, w, _ = self._cache[key]
-                #         if not self._is_writer_closed(w) and not r.at_eof():
-                #             self._cache[key] = (r, w, now)
-                #             return r, w
-
-                # perform creation via ziti context
-                extension, instance = key
-                logger.info(f"Create connection to {extension}: {instance}")
-                # loop = asyncio.get_running_loop()
-                # sock = await loop.run_in_executor(None, self._ziti_ctx.connect,
-                #                                   extension, instance)
-                if instance:
-                    sock = self._ziti_ctx.connect(
-                        extension, terminator=instance
-                    )  # , terminator=instance)
-                else:
-                    sock = self._ziti_ctx.connect(extension)
-                orig_reader, orig_writer = await asyncio.open_connection(sock=sock)
-
-                reader = CachedStreamReader(orig_reader, key, self)
-                writer = CachedStreamWriter(orig_writer, key, self)
-
-                async with self._lock:
-                    self._cache[key] = CachedStreamEntry(reader, writer, time.time())
-
-                return reader, writer
-            finally:
-                async with self._lock:
-                    self._in_progress.pop(key, None)
+        for key in expired_keys:
+            pair = self._active_pairs.pop(key, None)
+            if pair:
+                await self._close_pair(pair)
diff --git a/pyproject.toml b/pyproject.toml
index 38330ec..452e8c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ authors = [
 license = { file = "LICENSE.txt" }
 requires-python = ">=3.12,<4"
 dependencies = [
+    "aiocache>=0.12.3,<0.13.0",
     "asn1crypto>=1.5.1,<2.0.0",
     "cryptography>=45.0.7,<46.0.0",
     "dynaconf>=3.2.11,<4.0.0",
diff --git a/uv.lock b/uv.lock
index 66e9760..5aece5a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2,6 +2,15 @@ version = 1
 revision = 3
 requires-python = ">=3.12, <4"
 
+[[package]]
+name = "aiocache"
+version = "0.12.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7a/64/b945b8025a9d1e6e2138845f4022165d3b337f55f50984fbc6a4c0a1e355/aiocache-0.12.3.tar.gz", hash = "sha256:f528b27bf4d436b497a1d0d1a8f59a542c153ab1e37c3621713cb376d44c4713", size = 132196, upload-time = "2024-09-25T13:20:23.823Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/37/d7/15d67e05b235d1ed8c3ce61688fe4d84130e72af1657acadfaac3479f4cf/aiocache-0.12.3-py2.py3-none-any.whl", hash = "sha256:889086fc24710f431937b87ad3720a289f7fc31c4fd8b68e9f918b9bacd8270d", size = 28199, upload-time = "2024-09-25T13:20:22.688Z" },
+]
+
 [[package]]
 name = "aiohappyeyeballs"
 version = "2.6.1"
@@ -1418,6 +1427,7 @@ name = "mrok"
 version = "0.0.0.dev0"
 source = { editable = "." }
 dependencies = [
+    { name = "aiocache" },
     { name = "asn1crypto" },
     { name = "cryptography" },
     { name = "dynaconf" },
@@ -1463,6 +1473,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
+    { name = "aiocache", specifier = ">=0.12.3,<0.13.0" },
     { name = "asn1crypto", specifier = ">=1.5.1,<2.0.0" },
     { name = "cryptography", specifier = ">=45.0.7,<46.0.0" },
     { name = "dynaconf", specifier = ">=3.2.11,<4.0.0" },