From 0257fc792e722775f0af83b61484696207de6dfd Mon Sep 17 00:00:00 2001 From: Maks Osowski Date: Wed, 5 Mar 2025 18:08:08 +0100 Subject: [PATCH 1/2] feat(flagd): Add features to customize auth to Sync API servers Signed-off-by: Maks Osowski --- .../contrib/provider/flagd/config.py | 16 ++++++++++++++++ .../contrib/provider/flagd/provider.py | 7 +++++++ .../process/connector/grpc_watcher.py | 19 +++++++++++++++++-- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py index ac6134a7..86d1cc9a 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py @@ -1,4 +1,5 @@ import dataclasses +import grpc import os import typing from enum import Enum @@ -45,9 +46,11 @@ class CacheType(Enum): ENV_VAR_RETRY_BACKOFF_MAX_MS = "FLAGD_RETRY_BACKOFF_MAX_MS" ENV_VAR_RETRY_GRACE_PERIOD_SECONDS = "FLAGD_RETRY_GRACE_PERIOD" ENV_VAR_SELECTOR = "FLAGD_SOURCE_SELECTOR" +ENV_VAR_PROVIDER_ID = "FLAGD_SOURCE_PROVIDER_ID" ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS" ENV_VAR_TLS = "FLAGD_TLS" ENV_VAR_TLS_CERT = "FLAGD_SERVER_CERT_PATH" +ENV_VAR_DEFAULT_AUTHORITY = "FLAGD_DEFAULT_AUTHORITY" T = typing.TypeVar("T") @@ -81,6 +84,7 @@ def __init__( # noqa: PLR0913 port: typing.Optional[int] = None, tls: typing.Optional[bool] = None, selector: typing.Optional[str] = None, + provider_id: typing.Optional[str] = None, resolver: typing.Optional[ResolverType] = None, offline_flag_source_path: typing.Optional[str] = None, offline_poll_interval_ms: typing.Optional[int] = None, @@ -93,6 +97,8 @@ def __init__( # noqa: PLR0913 cache: typing.Optional[CacheType] = None, max_cache_size: typing.Optional[int] = None, cert_path: typing.Optional[str] = None, + default_authority: typing.Optional[str] = None, + channel_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host @@ -227,3 +233,13 @@ def __init__( # noqa: PLR0913 self.selector = ( env_or_default(ENV_VAR_SELECTOR, None) if selector is None else selector ) + + self.provider_id = ( + env_or_default(ENV_VAR_PROVIDER_ID, None) if provider_id is None else provider_id + ) + + self.default_authority = ( + env_or_default(ENV_VAR_DEFAULT_AUTHORITY, None) if default_authority is None else default_authority + ) + + self.channel_credentials = channel_credentials diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py index 83b03897..ef36c925 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py @@ -21,6 +21,7 @@ # provider.initialise(schema="https",endpoint="example.com",port=1234,timeout=10) """ +import grpc import typing import warnings @@ -47,6 +48,7 @@ def __init__( # noqa: PLR0913 timeout: typing.Optional[int] = None, retry_backoff_ms: typing.Optional[int] = None, selector: typing.Optional[str] = None, + provider_id: typing.Optional[str] = None, resolver_type: typing.Optional[ResolverType] = None, offline_flag_source_path: typing.Optional[str] = None, stream_deadline_ms: typing.Optional[int] = None, @@ -56,6 +58,8 @@ def __init__( # noqa: PLR0913 retry_backoff_max_ms: typing.Optional[int] = None, retry_grace_period: typing.Optional[int] = None, cert_path: typing.Optional[str] = None, + default_authority: typing.Optional[str] = None, + grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): """ Create an instance of the FlagdProvider @@ -88,6 +92,7 @@ def __init__( # noqa: PLR0913 retry_backoff_max_ms=retry_backoff_max_ms, retry_grace_period=retry_grace_period, selector=selector, + provider_id=provider_id, resolver=resolver_type, offline_flag_source_path=offline_flag_source_path, stream_deadline_ms=stream_deadline_ms, @@ -95,6 +100,8 @@ def __init__( # noqa: PLR0913 cache=cache, max_cache_size=max_cache_size, cert_path=cert_path, + default_authority=default_authority, + channel_credentials=grpc_credentials, ) self.resolver = self.setup_resolver() diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py index 138a5ddb..8e1466a1 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py @@ -41,6 +41,7 @@ def __init__( self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001 self.deadline = config.deadline_ms * 0.001 self.selector = config.selector + self.provider_id = config.provider_id self.emit_provider_ready = emit_provider_ready self.emit_provider_error = emit_provider_error self.emit_provider_stale = emit_provider_stale @@ -60,7 +61,17 @@ def _generate_channel(self, config: Config) -> grpc.Channel: ("grpc.max_reconnect_backoff_ms", config.retry_backoff_max_ms), ("grpc.min_reconnect_backoff_ms", config.stream_deadline_ms), ] - if config.tls: + if config.default_authority is not None: + options.append(("grpc.default_authority", config.default_authority)) + + if config.channel_credentials is not None: + channel_args = { + "options": options, + "credentials": config.channel_credentials, + } + channel = grpc.secure_channel(target, **channel_args) + + elif config.tls: channel_args = { "options": options, "credentials": grpc.ssl_channel_credentials(), @@ -153,7 +164,11 @@ def listen(self) -> None: if self.streamline_deadline_seconds > 0 else {} ) - request_args = {"selector": self.selector} if self.selector is not None else {} + request_args = {} + if self.selector is not None: + request_args["selector"] = self.selector + if self.provider_id is not None: + request_args["provider_id"] = self.provider_id while self.active: try: From 54428d9855806e44f4aad60a0151414027e3c3a5 Mon Sep 17 00:00:00 2001 From: Maks Osowski Date: Thu, 6 Mar 2025 16:51:57 +0100 Subject: [PATCH 2/2] chore(flagd): Fix var names, type checking, and linting Signed-off-by: Maks Osowski --- .../src/openfeature/contrib/provider/flagd/config.py | 11 ++++++++--- .../openfeature/contrib/provider/flagd/provider.py | 7 ++++--- .../flagd/resolvers/process/connector/grpc_watcher.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py index 86d1cc9a..59b86e2a 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py @@ -1,9 +1,10 @@ import dataclasses -import grpc import os import typing from enum import Enum +import grpc + class ResolverType(Enum): RPC = "rpc" @@ -235,11 +236,15 @@ def __init__( # noqa: PLR0913 ) self.provider_id = ( - env_or_default(ENV_VAR_PROVIDER_ID, None) if provider_id is None else provider_id + env_or_default(ENV_VAR_PROVIDER_ID, None) + if provider_id is None + else provider_id ) self.default_authority = ( - env_or_default(ENV_VAR_DEFAULT_AUTHORITY, None) if default_authority is None else default_authority + env_or_default(ENV_VAR_DEFAULT_AUTHORITY, None) + if default_authority is None + else default_authority ) self.channel_credentials = channel_credentials diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py index ef36c925..0716e0a9 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py @@ -21,10 +21,11 @@ # provider.initialise(schema="https",endpoint="example.com",port=1234,timeout=10) """ -import grpc import typing import warnings +import grpc + from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagResolutionDetails from openfeature.provider import AbstractProvider @@ -59,7 +60,7 @@ def __init__( # noqa: PLR0913 retry_grace_period: typing.Optional[int] = None, cert_path: typing.Optional[str] = None, default_authority: typing.Optional[str] = None, - grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, + channel_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): """ Create an instance of the FlagdProvider @@ -101,7 +102,7 @@ def __init__( # noqa: PLR0913 max_cache_size=max_cache_size, cert_path=cert_path, default_authority=default_authority, - channel_credentials=grpc_credentials, + channel_credentials=channel_credentials, ) self.resolver = self.setup_resolver() diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py index 8e1466a1..f3e9c4a6 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py @@ -55,7 +55,7 @@ def __init__( def _generate_channel(self, config: Config) -> grpc.Channel: target = f"{config.host}:{config.port}" # Create the channel with the service config - options = [ + options: list[tuple[str, typing.Any]] = [ ("grpc.keepalive_time_ms", config.keep_alive_time), ("grpc.initial_reconnect_backoff_ms", config.retry_backoff_ms), ("grpc.max_reconnect_backoff_ms", config.retry_backoff_max_ms),