diff --git a/monitoring/monitorlib/fetch/__init__.py b/monitoring/monitorlib/fetch/__init__.py index 17dcf9d2a1..76ed6fe7a0 100644 --- a/monitoring/monitorlib/fetch/__init__.py +++ b/monitoring/monitorlib/fetch/__init__.py @@ -7,18 +7,23 @@ from dataclasses import dataclass from enum import Enum from http.client import RemoteDisconnected -from typing import Self, TypeVar +from typing import Optional, Self, TypeVar from urllib.parse import urlparse import flask import jwt import requests import urllib3 -from implicitdict import ImplicitDict, Optional, StringBasedDateTime +from implicitdict import ( + ImplicitDict, + StringBasedDateTime, + StringBasedTimeDelta, +) from loguru import logger from monitoring.monitorlib import infrastructure from monitoring.monitorlib.errors import stacktrace_string +from monitoring.monitorlib.infrastructure import AUTHORIZATION_DT from monitoring.monitorlib.rid import RIDVersion @@ -54,6 +59,9 @@ class RequestDescription(ImplicitDict): initiated_at: Optional[StringBasedDateTime] received_at: Optional[StringBasedDateTime] + auth_dt: Optional[StringBasedTimeDelta] + """Amount of time required to obtain authorization before performing the primary query (de minimus or unknown by default).""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if "headers" not in self: @@ -124,6 +132,11 @@ def describe_request( "initiated_at": StringBasedDateTime(initiated_at), "headers": headers, } + authorization_dt: datetime.timedelta | None = getattr(req, AUTHORIZATION_DT, None) + if authorization_dt: + kwargs["auth_dt"] = StringBasedTimeDelta( + f"{authorization_dt.total_seconds():.4g}s" + ) body = req.body.decode("utf-8") if req.body else None try: if body: diff --git a/monitoring/monitorlib/infrastructure.py b/monitoring/monitorlib/infrastructure.py index 194f1cefc7..8464f80f9f 100644 --- a/monitoring/monitorlib/infrastructure.py +++ b/monitoring/monitorlib/infrastructure.py @@ -7,6 +7,7 @@ import time import urllib.parse import weakref +from dataclasses import dataclass from enum import Enum import jwt @@ -28,11 +29,20 @@ CLIENT_TIMEOUT = 10 # seconds SOCKET_KEEP_ALIVE_LIMIT = 57 # seconds. +AUTHORIZATION_DT = "authorization_dt" +"""This attribute may be added to a PreparedRequest indicating the timedelta required to obtain authorization""" + AuthSpec = str """Specification for means by which to obtain access tokens.""" +@dataclass +class AdditionalHeaders: + headers: dict[str, str] + token_issuance_seconds: float | None = None + + class AuthAdapter: """Base class for an adapter that add JWTs to requests.""" @@ -44,33 +54,48 @@ def issue_token(self, intended_audience: str, scopes: list[str]) -> str: raise NotImplementedError() - def get_headers(self, url: str, scopes: list[str] | None = None) -> dict[str, str]: + def get_headers( + self, url: str, scopes: list[str] | None = None + ) -> AdditionalHeaders: if scopes is None: scopes = ALL_SCOPES scopes = [s.value if isinstance(s, Enum) else s for s in scopes] intended_audience = urllib.parse.urlparse(url).hostname if not intended_audience: - return {} + return AdditionalHeaders(headers={}) scope_string = " ".join(scopes) if intended_audience not in self._tokens: self._tokens[intended_audience] = {} if scope_string not in self._tokens[intended_audience]: + t0 = time.monotonic() token = self.issue_token(intended_audience, scopes) + dt_s = time.monotonic() - t0 else: token = self._tokens[intended_audience][scope_string] + dt_s = None payload = jwt.decode(token, options={"verify_signature": False}) expires = EPOCH + datetime.timedelta(seconds=payload["exp"]) if datetime.datetime.now(datetime.UTC) > expires - TOKEN_REFRESH_MARGIN: + t0 = time.monotonic() token = self.issue_token(intended_audience, scopes) + dt_s = (dt_s or 0) + (time.monotonic() - t0) self._tokens[intended_audience][scope_string] = token - return {"Authorization": "Bearer " + token} + return AdditionalHeaders( + headers={"Authorization": "Bearer " + token}, token_issuance_seconds=dt_s + ) - def add_headers(self, request: requests.PreparedRequest, scopes: list[str]): + def add_headers( + self, request: requests.PreparedRequest, scopes: list[str] + ) -> AdditionalHeaders: if request.url: - for k, v in self.get_headers(request.url, scopes).items(): + additional_headers = self.get_headers(request.url, scopes) + for k, v in additional_headers.headers.items(): request.headers[k] = v + return additional_headers + else: + return AdditionalHeaders(headers={}) def get_sub(self) -> str | None: """Retrieve `sub` claim from one of the existing tokens""" @@ -182,6 +207,21 @@ def prepare_request(self, request, **kwargs): return super().prepare_request(request, **kwargs) + def add_auth( + self, prepared_request: requests.PreparedRequest, scopes: list[str] | None + ) -> requests.PreparedRequest: + if scopes and self.auth_adapter: + additional_headers = self.auth_adapter.add_headers(prepared_request, scopes) + if additional_headers.token_issuance_seconds: + setattr( + prepared_request, + AUTHORIZATION_DT, + datetime.timedelta( + seconds=additional_headers.token_issuance_seconds + ), + ) + return prepared_request + def adjust_request_kwargs(self, kwargs): if self.auth_adapter: scopes = None @@ -194,14 +234,7 @@ def adjust_request_kwargs(self, kwargs): if scopes is None: scopes = self.default_scopes - def auth( - prepared_request: requests.PreparedRequest, - ) -> requests.PreparedRequest: - if scopes and self.auth_adapter: - self.auth_adapter.add_headers(prepared_request, scopes) - return prepared_request - - kwargs["auth"] = auth + kwargs["auth"] = lambda req: self.add_auth(req, scopes) if "timeout" not in kwargs: kwargs["timeout"] = self.timeout_seconds return kwargs @@ -295,10 +328,8 @@ def adjust_request_kwargs(self, url, method, kwargs): raise ValueError( "All tests must specify auth scope for all session requests. Either specify as an argument for each individual HTTP call, or decorate the test with @default_scope." ) - headers = {} - for k, v in self.auth_adapter.get_headers(url, scopes).items(): - headers[k] = v - kwargs["headers"] = headers + additional_headers = self.auth_adapter.get_headers(url, scopes) + kwargs["headers"] = additional_headers.headers if method == "PUT" and kwargs.get("data"): kwargs["json"] = kwargs["data"] del kwargs["data"] diff --git a/monitoring/uss_qualifier/resources/communications/client_identity.py b/monitoring/uss_qualifier/resources/communications/client_identity.py index 51975995af..a5c3ad9812 100644 --- a/monitoring/uss_qualifier/resources/communications/client_identity.py +++ b/monitoring/uss_qualifier/resources/communications/client_identity.py @@ -55,7 +55,7 @@ def subject(self) -> str: # we force one using the client identify audience and scopes # Trigger a caching initial token request so that adapter.get_sub() will return something - headers = self._adapter.get_headers( + additional_headers = self._adapter.get_headers( f"https://{self.specification.whoami_audience}", [self.specification.whoami_scope], ) @@ -66,7 +66,7 @@ def subject(self) -> str: raise ValueError( f"subject is None, meaning `sub` claim was not found in payload of token, " f"using {type(self._adapter).__name__} requesting {self.specification.whoami_scope} scope " - f"for {self.specification.whoami_audience} audience: {headers['Authorization'][len('Bearer: ') :]}" + f"for {self.specification.whoami_audience} audience: {additional_headers.headers['Authorization'][len('Bearer: ') :]}" ) return sub diff --git a/monitoring/uss_qualifier/scenarios/astm/netrid/common/dss/heavy_traffic_concurrent.py b/monitoring/uss_qualifier/scenarios/astm/netrid/common/dss/heavy_traffic_concurrent.py index b3f0c2c30b..369060f020 100644 --- a/monitoring/uss_qualifier/scenarios/astm/netrid/common/dss/heavy_traffic_concurrent.py +++ b/monitoring/uss_qualifier/scenarios/astm/netrid/common/dss/heavy_traffic_concurrent.py @@ -206,6 +206,7 @@ async def _get_isa(self, isa_id): "GET", url, ) + # TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made prep = self._dss.client.prepare_request(r) t0 = datetime.now(UTC) req_descr = describe_request(prep, t0) @@ -242,6 +243,7 @@ async def _create_isa(self, isa_id): url, json=payload, ) + # TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made prep = self._dss.client.prepare_request(r) t0 = datetime.now(UTC) req_descr = describe_request(prep, t0) @@ -272,6 +274,7 @@ async def _delete_isa(self, isa_id, isa_version): "DELETE", url, ) + # TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made prep = self._dss.client.prepare_request(r) t0 = datetime.now(UTC) req_descr = describe_request(prep, t0) diff --git a/pyproject.toml b/pyproject.toml index f72b26c935..4557ba9c25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,18 +63,22 @@ default-groups = [] [tool.ruff] target-version = "py313" +extend-exclude = [ + "interfaces/*", + "monitoring/prober/output/*", + "monitoring/mock_uss/output/*", + "monitoring/uss_qualifier/output/*", +] +line-length = 88 +[tool.ruff.lint] # Default + isort + pyupgrade -lint.select = [ +select = [ "E4", "E7", "E9", "F", "I", "UP" ] -extend-exclude = [ - "interfaces/*", - "monitoring/prober/output/*", - "monitoring/mock_uss/output/*", - "monitoring/uss_qualifier/output/*", -] -line-length = 88 +# Explicitly ignore UP045 (Optional[Foo] -> Foo | None) +ignore = ["UP045"] +unfixable = ["UP045"] [tool.basedpyright] typeCheckingMode = "standard" diff --git a/schemas/monitoring/monitorlib/fetch/RequestDescription.json b/schemas/monitoring/monitorlib/fetch/RequestDescription.json index 167ef825a4..2348ab67bc 100644 --- a/schemas/monitoring/monitorlib/fetch/RequestDescription.json +++ b/schemas/monitoring/monitorlib/fetch/RequestDescription.json @@ -7,6 +7,14 @@ "description": "Path to content that replaces the $ref", "type": "string" }, + "auth_dt": { + "description": "Amount of time required to obtain authorization before performing the primary query (de minimus or unknown by default).", + "format": "duration", + "type": [ + "string", + "null" + ] + }, "body": { "type": [ "string",