Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions mrok/proxy/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path

from mrok.conf import get_settings
Expand Down Expand Up @@ -35,37 +34,30 @@ def __init__(
self._conn_manager = ZitiConnectionManager(
identity_file,
ttl_seconds=ziti_connection_ttl_seconds,
purge_interval=ziti_conn_cache_purge_interval_seconds,
cleanup_interval=ziti_conn_cache_purge_interval_seconds,
)

def get_target_from_header(self, name: str, headers: dict[str, str]) -> str:
header_value = headers.get(name)
if not header_value:
raise ProxyError(
f"Header {name} not found!",
)
if ":" in header_value:
header_value, _ = header_value.split(":", 1)
if not header_value.endswith(self._proxy_wildcard_domain):
raise ProxyError(f"Unexpected value for {name} header: `{header_value}`.")

return header_value[: -len(self._proxy_wildcard_domain)]
def get_target_from_header(self, headers: dict[str, str], name: str) -> str | None:
header_value = headers.get(name, "")
if self._proxy_wildcard_domain in header_value:
if ":" in header_value:
header_value, _ = header_value.split(":", 1)
return header_value[: -len(self._proxy_wildcard_domain)]

def get_target_name(self, headers: dict[str, str]) -> str:
try:
return self.get_target_from_header("x-forwared-for", headers)
except ProxyError as pe:
logger.warning(pe)
return self.get_target_from_header("host", headers)
target = self.get_target_from_header(headers, "x-forwarded-host")
if not target:
target = self.get_target_from_header(headers, "host")
if not target:
raise ProxyError("Neither Host nor X-Forwarded-Host contain a valid target name")
return target

async def startup(self):
setup_logging(get_settings())
await self._conn_manager.start()
logger.info(f"Proxy app startup completed: {os.getpid()}")

async def shutdown(self):
await self._conn_manager.stop()
logger.info(f"Proxy app shutdown completed: {os.getpid()}")

async def select_backend(
self,
Expand All @@ -74,4 +66,4 @@ async def select_backend(
) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
target_name = self.get_target_name(headers)

return await self._conn_manager.get(target_name)
return await self._conn_manager.get_or_create(target_name)
6 changes: 3 additions & 3 deletions mrok/proxy/streams.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio

from mrok.proxy.types import ConnectionCache, ConnectionKey
from mrok.proxy.types import ConnectionCache


class CachedStreamReader:
def __init__(
self,
reader: asyncio.StreamReader,
key: ConnectionKey,
key: str,
manager: ConnectionCache,
):
self._reader = reader
Expand Down Expand Up @@ -77,7 +77,7 @@ class CachedStreamWriter:
def __init__(
self,
writer: asyncio.StreamWriter,
key: ConnectionKey,
key: str,
manager: ConnectionCache,
):
self._writer = writer
Expand Down
5 changes: 2 additions & 3 deletions mrok/proxy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from mrok.http.types import StreamReader, StreamWriter

ConnectionKey = tuple[str, str | None]
CachedStream = tuple[StreamReader, StreamWriter]
StreamPair = tuple[StreamReader, StreamWriter]


class ConnectionCache(Protocol):
async def invalidate(self, key: ConnectionKey) -> None: ...
async def invalidate(self, key: str) -> None: ...
224 changes: 84 additions & 140 deletions mrok/proxy/ziti.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
"""Ziti-backed connection manager for the proxy.

This manager owns creation of connections via an OpenZiti context, wraps
streams to observe IO errors, evicts idle entries, and serializes creation
per-key.
"""

import asyncio
import contextlib
import logging
import time
from pathlib import Path

# typing imports intentionally minimized
import openziti
from aiocache import Cache

from mrok.http.types import StreamReader, StreamWriter
from mrok.proxy.dataclasses import CachedStreamEntry
from mrok.proxy.streams import CachedStreamReader, CachedStreamWriter
from mrok.proxy.types import CachedStream, ConnectionKey
from mrok.proxy.types import StreamPair

logger = logging.getLogger("mrok.proxy")

Expand All @@ -27,147 +18,100 @@
identity_file: str | Path,
ziti_timeout_ms: int = 10000,
ttl_seconds: float = 60.0,
purge_interval: float = 10.0,
cleanup_interval: float = 10.0,
):
self._identity_file = identity_file
self._ziti_ctx = None
self._ziti_timeout_ms = ziti_timeout_ms
self._ttl = float(ttl_seconds)
self._purge_interval = float(purge_interval)
self._cache: dict[ConnectionKey, CachedStreamEntry] = {}
self._lock = asyncio.Lock()
self._in_progress: dict[ConnectionKey, asyncio.Lock] = {}
self._purge_task: asyncio.Task | None = None

async def get(self, target: str) -> tuple[StreamReader, StreamWriter] | tuple[None, None]:
head, _, tail = target.partition(".")
terminator = target if head and tail else ""
service = tail if tail else head
r, w = await self._get_or_create_key((service, terminator))
return r, w

async def invalidate(self, key: ConnectionKey) -> None:
async with self._lock:
item = self._cache.pop(key, None)
if item is None:
return
await self._close_writer(item.writer)
self.identity_file = identity_file
self.ziti_timeout_ms = ziti_timeout_ms
self.ttl_seconds = ttl_seconds
self.cleanup_interval = cleanup_interval

self.cache = Cache(Cache.MEMORY)

self._active_pairs: dict[str, StreamPair] = {}

self._cleanup_task: asyncio.Task | None = None
self._ziti_ctx: openziti.context.ZitiContext | None = None

async def create_stream_pair(self, key: str) -> StreamPair:
if not self._ziti_ctx:
raise Exception("ZitiConnectionManager is not started")

Check warning on line 37 in mrok/proxy/ziti.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace this generic exception class with a more specific one.

See more on https://sonarcloud.io/project/issues?id=softwareone-platform_mrok&issues=AZsNDYfDa9oHRE6yGOPD&open=AZsNDYfDa9oHRE6yGOPD&pullRequest=28
sock = self._ziti_ctx.connect(key)
orig_reader, orig_writer = await asyncio.open_connection(sock=sock)

reader = CachedStreamReader(orig_reader, key, self)
writer = CachedStreamWriter(orig_writer, key, self)
return (reader, writer)

async def get_or_create(self, key: str) -> StreamPair:
pair = await self.cache.get(key)

if pair:
logger.info(f"return cached connection for {key}")
await self.cache.set(key, pair, ttl=self.ttl_seconds)
self._active_pairs[key] = pair
return pair

pair = await self.create_stream_pair(key)
await self.cache.set(key, pair, ttl=self.ttl_seconds)
self._active_pairs[key] = pair
logger.info(f"return new connection for {key}")
return pair

async def invalidate(self, key: str) -> None:
logger.info(f"invalidating connection for {key}")
pair = await self.cache.get(key)
if pair:
await self._close_pair(pair)

await self.cache.delete(key)
self._active_pairs.pop(key, None)

async def start(self) -> None:
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
if self._ziti_ctx is None:
ctx, err = openziti.load(str(self._identity_file), timeout=self._ziti_timeout_ms)
ctx, err = openziti.load(str(self.identity_file), timeout=self.ziti_timeout_ms)
if err != 0:
raise Exception(f"Cannot create a Ziti context from the identity file: {err}")
self._ziti_ctx = ctx
if self._purge_task is None:
self._purge_task = asyncio.create_task(self._purge_loop())
logger.info("Ziti connection manager started")

async def stop(self) -> None:
if self._purge_task is not None:
self._purge_task.cancel()
try:
await self._purge_task
except asyncio.CancelledError:
logger.debug("Purge task was cancelled")
except Exception as e:
logger.warning(f"An error occurred stopping the purge task: {e}")
self._purge_task = None
logger.info("Ziti connection manager stopped")

async with self._lock:
items = list(self._cache.items())
self._cache.clear()

for _, item in items:
await self._close_writer(item.writer)

async def _purge_loop(self) -> None:
if self._cleanup_task:
self._cleanup_task.cancel()
with contextlib.suppress(Exception):
await self._cleanup_task

for pair in list(self._active_pairs.values()):
await self._close_pair(pair)

self._active_pairs.clear()
await self.cache.clear()
openziti.shutdown()

@staticmethod
async def _close_pair(pair: StreamPair) -> None:
reader, writer = pair

Check warning on line 93 in mrok/proxy/ziti.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace the unused local variable "reader" with "_".

See more on https://sonarcloud.io/project/issues?id=softwareone-platform_mrok&issues=AZsNDYfDa9oHRE6yGOPE&open=AZsNDYfDa9oHRE6yGOPE&pullRequest=28
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()

async def _periodic_cleanup(self) -> None:
try:
while True:
await asyncio.sleep(self._purge_interval)
await self._purge_once()
await asyncio.sleep(self.cleanup_interval)
await self._cleanup_once()
except asyncio.CancelledError:
return

async def _purge_once(self) -> None:
to_close: list[tuple[StreamReader, StreamWriter]] = []
async with self._lock:
now = time.time()
for key, item in list(self._cache.items()):
if now - item.last_access > self._ttl:
to_close.append((item.reader, item.writer))
del self._cache[key]

for _, writer in to_close:
writer.close()
await self._close_writer(writer)
async def _cleanup_once(self) -> None:
# Keys currently stored in aiocache
keys_in_cache = set(await self.cache.keys())
# Keys we think are alive
known_keys = set(self._active_pairs.keys())

def _is_writer_closed(self, writer: StreamWriter) -> bool:
return writer.transport.is_closing()
expired_keys = known_keys - keys_in_cache

async def _close_writer(self, writer: StreamWriter) -> None:
writer.close()
try:
await writer.wait_closed()
except Exception as e:
logger.debug(f"Error closing writer: {e}")

async def _get_or_create_key(self, key: ConnectionKey) -> CachedStream:
"""Internal: create or return a cached wrapped pair for the concrete key."""
await self._purge_once()
to_close = None
async with self._lock:
if key in self._cache:
now = time.time()
item = self._cache[key]
reader, writer = item.reader, item.writer
if not self._is_writer_closed(writer) and not reader.at_eof():
self._cache[key] = CachedStreamEntry(reader, writer, now)
return reader, writer
to_close = writer
del self._cache[key]

lock = self._in_progress.get(key)
if lock is None:
lock = asyncio.Lock()
self._in_progress[key] = lock

if to_close:
await self._close_writer(to_close)

async with lock:
try:
# # double-check cache after acquiring the per-key lock
# async with self._lock:
# now = time.time()
# if key in self._cache:
# r, w, _ = self._cache[key]
# if not self._is_writer_closed(w) and not r.at_eof():
# self._cache[key] = (r, w, now)
# return r, w

# perform creation via ziti context
extension, instance = key
logger.info(f"Create connection to {extension}: {instance}")
# loop = asyncio.get_running_loop()
# sock = await loop.run_in_executor(None, self._ziti_ctx.connect,
# extension, instance)
if instance:
sock = self._ziti_ctx.connect(
extension, terminator=instance
) # , terminator=instance)
else:
sock = self._ziti_ctx.connect(extension)
orig_reader, orig_writer = await asyncio.open_connection(sock=sock)

reader = CachedStreamReader(orig_reader, key, self)
writer = CachedStreamWriter(orig_writer, key, self)

async with self._lock:
self._cache[key] = CachedStreamEntry(reader, writer, time.time())

return reader, writer
finally:
async with self._lock:
self._in_progress.pop(key, None)
for key in expired_keys:
pair = self._active_pairs.pop(key, None)
if pair:
await self._close_pair(pair)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ authors = [
license = { file = "LICENSE.txt" }
requires-python = ">=3.12,<4"
dependencies = [
"aiocache>=0.12.3,<0.13.0",
"asn1crypto>=1.5.1,<2.0.0",
"cryptography>=45.0.7,<46.0.0",
"dynaconf>=3.2.11,<4.0.0",
Expand Down
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading