Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ jobs:

- name: Create version tag
shell: bash
run: echo "tag=docker.elaad.io/elaadnl/bl-reference-implementation:$(git show -s --format="%ct-%h" $GITHUB_SHA)" >> $GITHUB_ENV
run: echo "tag=docker.elaad.io/elaadnl/ditm-openadr-bl:$(git show -s --format="%ct-%h" $GITHUB_SHA)" >> $GITHUB_ENV
- name: Latest tag on main branch
if: github.ref == 'refs/heads/main'
run: echo "tag_main=,docker.elaad.io/elaadnl/bl-reference-implementation:latest" >> $GITHUB_ENV
run: echo "tag_main=,docker.elaad.io/elaadnl/ditm-openadr-bl:latest" >> $GITHUB_ENV

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
Expand Down
86 changes: 72 additions & 14 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ authors = [
readme = "README.md"
requires-python = ">=3.12, <4"
dependencies = [
"openadr3-client-gac-compliance (>=1.4.0,<2.0.0)",
"openadr3-client-gac-compliance (==2.0.0)",
"python-decouple (>=3.8,<4.0)",
"influxdb-client[async] (>=1.49.0,<2.0.0)",
"ruff (>=0.12.4,<0.13.0)",
"mypy (>=1.17.0,<2.0.0)",
"pytest (>=8.4.1,<9.0.0)",
"azure-functions (>=1.23.0,<2.0.0)",
"openadr3-client (>=0.0.7,<0.0.8)"
"openadr3-client (==0.0.11)",
"requests (>=2.32.5,<3.0.0)",
"requests-oauthlib (>=2.0.0,<3.0.0)",
"holidays (>=0.83,<0.84)",
"types-requests (>=2.32.4.20250913,<3.0.0.0)",
]

[build-system]
Expand All @@ -28,6 +32,8 @@ package-mode = false

[tool.poetry.group.dev.dependencies]
pytest-cov = "^6.2.1"
types-oauthlib = "^3.3.0.20250822"
types-requests-oauthlib = "^2.0.0.20250809"

[[tool.mypy.overrides]]
module = ["decouple", "openadr3_client_gac_compliance"]
Expand Down
2 changes: 1 addition & 1 deletion src/application/generate_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _generate_capacity_limitation_intervals(
type=EventPayloadType.IMPORT_CAPACITY_LIMIT,
values=(
predicted_grid_asset_loads.flex_capacity_required(max_capacity)
or 4,
or 22,
),
),
),
Expand Down
22 changes: 22 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,25 @@
# INFLUXDB parameters (secrets)
INFLUXDB_TOKEN = config("INFLUXDB_TOKEN")
INFLUXDB_URL = config("INFLUXDB_URL")

PREDICTED_TRAFO_LOAD_BUCKET = config(
"PREDICTED_TRAFO_LOAD_BUCKET", default="ditm_model_output"
)
STANDARD_PROFILES_BUCKET_NAME = config(
"STANDARD_PROFILES_BUCKET_NAME", default="ditm_standard_profiles"
)
DALIDATA_BUCKET_NAME = config("DALIDATA_BUCKET_NAME", default="dalidata")

# External services URLs
WEATHER_FORECAST_API_URL = config("WEATHER_FORECAST_API_URL")

# Authentication to Azure ML managed endpoint for prediction model
DITM_MODEL_API_URL = config("DITM_MODEL_API_URL")
DITM_MODEL_API_CLIENT_ID = config("DITM_MODEL_API_CLIENT_ID")
DITM_MODEL_API_CLIENT_SECRET = config("DITM_MODEL_API_CLIENT_SECRET")
DITM_MODEL_API_TOKEN_URL = config("DITM_MODEL_API_TOKEN_URL")

OAUTH_CLIENT_ID = config("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = config("OAUTH_CLIENT_SECRET")
OAUTH_TOKEN_ENDPOINT = config("OAUTH_TOKEN_ENDPOINT")
OAUTH_SCOPES = config("OAUTH_SCOPES")
56 changes: 56 additions & 0 deletions src/infrastructure/_auth/http/authenticated_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Implementation of a HTTP session which has an associated access token that is send to every request."""

from typing import Optional
from requests import PreparedRequest, Session
from requests.auth import AuthBase

from src.config import (
DITM_MODEL_API_CLIENT_ID,
DITM_MODEL_API_CLIENT_SECRET,
DITM_MODEL_API_TOKEN_URL,
)
from src.infrastructure._auth.token_manager import (
OAuthTokenManager,
OAuthTokenManagerConfig,
)


class _BearerAuth(AuthBase):
"""AuthBase implementation that includes a bearer token in all requests."""

def __init__(self, token_manager: OAuthTokenManager) -> None:
self._token_manager = token_manager

def __call__(self, r: PreparedRequest) -> PreparedRequest:
"""
Perform the request.

Adds the bearer token to the 'Authorization' request header before the call is made.
If the 'Authorization' was already present, it is replaced.
"""
# The token manager handles caching internally, so we can safely invoke this
# for each request.
r.headers["Authorization"] = "Bearer " + self._token_manager.get_access_token()
return r


class _BearerAuthenticatedSession(Session):
"""Session that includes a bearer token in all requests made through it."""

def __init__(
self,
token_manager: Optional[OAuthTokenManager] = None,
scopes: Optional[list[str]] = None,
) -> None:
super().__init__()
if not token_manager:
token_manager = OAuthTokenManager(
OAuthTokenManagerConfig(
client_id=DITM_MODEL_API_CLIENT_ID,
client_secret=DITM_MODEL_API_CLIENT_SECRET,
token_url=DITM_MODEL_API_TOKEN_URL,
scopes=scopes,
audience=None,
)
)
self.auth = _BearerAuth(token_manager)
88 changes: 88 additions & 0 deletions src/infrastructure/_auth/token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from threading import Lock

from oauthlib.oauth2 import BackendApplicationClient
from requests_oauthlib import OAuth2Session

from src.logger import logger


@dataclass
class OAuthTokenManagerConfig:
client_id: str
client_secret: str
token_url: str
scopes: list[str] | None
audience: str | None


class OAuthTokenManager:
"""An OAuth token manager responsible for the retrieval and caching of access tokens."""

def __init__(self, config: OAuthTokenManagerConfig) -> None:
self.client = BackendApplicationClient(
client_id=config.client_id,
scope=" ".join(config.scopes) if config.scopes is not None else None,
)
self.oauth = OAuth2Session(client=self.client)
self.token_url = config.token_url
self.client_secret = config.client_secret
self.audience = config.audience
if self.token_url is None:
msg = "token_url is required"
raise ValueError(msg)

if self.client_secret is None:
msg = "client_secret is required"
raise ValueError(msg)

self._lock = Lock()
self._cached_token: tuple[datetime, str] | None = None

def get_access_token(self) -> str:
"""
Retrieves an access token from the token manager.

If a cached token is present in the token manager, this token is returned.
If no cached token is present, a new token is fetched, cached and returned.

Returns:
str: The access token.

"""
with self._lock:
if self._cached_token:
expiration_time, token = self._cached_token

if expiration_time > datetime.now(tz=UTC):
return token

# If we reach here, the token has reached its expiration time.
# Remove the token and fetch a new one.
self._cached_token = None

return self._get_new_access_token()

def _get_new_access_token(self) -> str:
token_response = self.oauth.fetch_token(
token_url=self.token_url,
client_secret=self.client_secret,
audience=self.audience,
)

# Calculate expiration time (half of token lifetime)
expires_in_seconds = token_response.get("expires_in", 3600)
expiration_time = datetime.now(tz=UTC) + timedelta(
seconds=expires_in_seconds // 2
)

access_token = token_response.get("access_token")

if not access_token:
logger.error("OAuthTokenManager - access_token not present in response")
exc_msg = "Access token was not present in token response"
raise ValueError(exc_msg)

self._cached_token = (expiration_time, access_token)
return access_token
Loading
Loading