Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions monitoring/monitorlib/fetch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@
from dataclasses import dataclass
from enum import Enum
from http.client import RemoteDisconnected
from typing import Self, TypeVar
from typing import Optional, Self, TypeVar
from urllib.parse import urlparse

import flask
import jwt
import requests
import urllib3
from implicitdict import ImplicitDict, Optional, StringBasedDateTime
from implicitdict import (
ImplicitDict,
StringBasedDateTime,
StringBasedTimeDelta,
)
from loguru import logger

from monitoring.monitorlib import infrastructure
from monitoring.monitorlib.errors import stacktrace_string
from monitoring.monitorlib.infrastructure import AUTHORIZATION_DT
from monitoring.monitorlib.rid import RIDVersion


Expand Down Expand Up @@ -54,6 +59,9 @@ class RequestDescription(ImplicitDict):
initiated_at: Optional[StringBasedDateTime]
received_at: Optional[StringBasedDateTime]

auth_dt: Optional[StringBasedTimeDelta]
"""Amount of time required to obtain authorization before performing the primary query (de minimus or unknown by default)."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

de minimus

TIL :D


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "headers" not in self:
Expand Down Expand Up @@ -124,6 +132,11 @@ def describe_request(
"initiated_at": StringBasedDateTime(initiated_at),
"headers": headers,
}
authorization_dt: datetime.timedelta | None = getattr(req, AUTHORIZATION_DT, None)
if authorization_dt:
kwargs["auth_dt"] = StringBasedTimeDelta(
f"{authorization_dt.total_seconds():.4g}s"
)
body = req.body.decode("utf-8") if req.body else None
try:
if body:
Expand Down
65 changes: 48 additions & 17 deletions monitoring/monitorlib/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import urllib.parse
import weakref
from dataclasses import dataclass
from enum import Enum

import jwt
Expand All @@ -28,11 +29,20 @@
CLIENT_TIMEOUT = 10 # seconds
SOCKET_KEEP_ALIVE_LIMIT = 57 # seconds.

AUTHORIZATION_DT = "authorization_dt"
"""This attribute may be added to a PreparedRequest indicating the timedelta required to obtain authorization"""


AuthSpec = str
"""Specification for means by which to obtain access tokens."""


@dataclass
class AdditionalHeaders:
headers: dict[str, str]
token_issuance_seconds: float | None = None


class AuthAdapter:
"""Base class for an adapter that add JWTs to requests."""

Expand All @@ -44,33 +54,48 @@ def issue_token(self, intended_audience: str, scopes: list[str]) -> str:

raise NotImplementedError()

def get_headers(self, url: str, scopes: list[str] | None = None) -> dict[str, str]:
def get_headers(
self, url: str, scopes: list[str] | None = None
) -> AdditionalHeaders:
if scopes is None:
scopes = ALL_SCOPES
scopes = [s.value if isinstance(s, Enum) else s for s in scopes]
intended_audience = urllib.parse.urlparse(url).hostname

if not intended_audience:
return {}
return AdditionalHeaders(headers={})

scope_string = " ".join(scopes)
if intended_audience not in self._tokens:
self._tokens[intended_audience] = {}
if scope_string not in self._tokens[intended_audience]:
t0 = time.monotonic()
token = self.issue_token(intended_audience, scopes)
dt_s = time.monotonic() - t0
else:
token = self._tokens[intended_audience][scope_string]
dt_s = None
payload = jwt.decode(token, options={"verify_signature": False})
expires = EPOCH + datetime.timedelta(seconds=payload["exp"])
if datetime.datetime.now(datetime.UTC) > expires - TOKEN_REFRESH_MARGIN:
t0 = time.monotonic()
token = self.issue_token(intended_audience, scopes)
dt_s = (dt_s or 0) + (time.monotonic() - t0)
self._tokens[intended_audience][scope_string] = token
return {"Authorization": "Bearer " + token}
return AdditionalHeaders(
headers={"Authorization": "Bearer " + token}, token_issuance_seconds=dt_s
)

def add_headers(self, request: requests.PreparedRequest, scopes: list[str]):
def add_headers(
self, request: requests.PreparedRequest, scopes: list[str]
) -> AdditionalHeaders:
if request.url:
for k, v in self.get_headers(request.url, scopes).items():
additional_headers = self.get_headers(request.url, scopes)
for k, v in additional_headers.headers.items():
request.headers[k] = v
return additional_headers
else:
return AdditionalHeaders(headers={})

def get_sub(self) -> str | None:
"""Retrieve `sub` claim from one of the existing tokens"""
Expand Down Expand Up @@ -182,6 +207,21 @@ def prepare_request(self, request, **kwargs):

return super().prepare_request(request, **kwargs)

def add_auth(
self, prepared_request: requests.PreparedRequest, scopes: list[str] | None
) -> requests.PreparedRequest:
if scopes and self.auth_adapter:
additional_headers = self.auth_adapter.add_headers(prepared_request, scopes)
if additional_headers.token_issuance_seconds:
setattr(
prepared_request,
AUTHORIZATION_DT,
datetime.timedelta(
seconds=additional_headers.token_issuance_seconds
),
)
return prepared_request

def adjust_request_kwargs(self, kwargs):
if self.auth_adapter:
scopes = None
Expand All @@ -194,14 +234,7 @@ def adjust_request_kwargs(self, kwargs):
if scopes is None:
scopes = self.default_scopes

def auth(
prepared_request: requests.PreparedRequest,
) -> requests.PreparedRequest:
if scopes and self.auth_adapter:
self.auth_adapter.add_headers(prepared_request, scopes)
return prepared_request

kwargs["auth"] = auth
kwargs["auth"] = lambda req: self.add_auth(req, scopes)
if "timeout" not in kwargs:
kwargs["timeout"] = self.timeout_seconds
return kwargs
Expand Down Expand Up @@ -295,10 +328,8 @@ def adjust_request_kwargs(self, url, method, kwargs):
raise ValueError(
"All tests must specify auth scope for all session requests. Either specify as an argument for each individual HTTP call, or decorate the test with @default_scope."
)
headers = {}
for k, v in self.auth_adapter.get_headers(url, scopes).items():
headers[k] = v
kwargs["headers"] = headers
additional_headers = self.auth_adapter.get_headers(url, scopes)
kwargs["headers"] = additional_headers.headers
if method == "PUT" and kwargs.get("data"):
kwargs["json"] = kwargs["data"]
del kwargs["data"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def subject(self) -> str:
# we force one using the client identify audience and scopes

# Trigger a caching initial token request so that adapter.get_sub() will return something
headers = self._adapter.get_headers(
additional_headers = self._adapter.get_headers(
f"https://{self.specification.whoami_audience}",
[self.specification.whoami_scope],
)
Expand All @@ -66,7 +66,7 @@ def subject(self) -> str:
raise ValueError(
f"subject is None, meaning `sub` claim was not found in payload of token, "
f"using {type(self._adapter).__name__} requesting {self.specification.whoami_scope} scope "
f"for {self.specification.whoami_audience} audience: {headers['Authorization'][len('Bearer: ') :]}"
f"for {self.specification.whoami_audience} audience: {additional_headers.headers['Authorization'][len('Bearer: ') :]}"
)

return sub
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ async def _get_isa(self, isa_id):
"GET",
url,
)
# TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made
prep = self._dss.client.prepare_request(r)
t0 = datetime.now(UTC)
req_descr = describe_request(prep, t0)
Expand Down Expand Up @@ -242,6 +243,7 @@ async def _create_isa(self, isa_id):
url,
json=payload,
)
# TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made
prep = self._dss.client.prepare_request(r)
t0 = datetime.now(UTC)
req_descr = describe_request(prep, t0)
Expand Down Expand Up @@ -272,6 +274,7 @@ async def _delete_isa(self, isa_id, isa_version):
"DELETE",
url,
)
# TODO: Do not rely on a prepared request that is not actually used in order to create the Query RequestDescription; instead build it from the request actually made
prep = self._dss.client.prepare_request(r)
t0 = datetime.now(UTC)
req_descr = describe_request(prep, t0)
Expand Down
20 changes: 12 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,22 @@ default-groups = []

[tool.ruff]
target-version = "py313"
extend-exclude = [
"interfaces/*",
"monitoring/prober/output/*",
"monitoring/mock_uss/output/*",
"monitoring/uss_qualifier/output/*",
]
line-length = 88

[tool.ruff.lint]
# Default + isort + pyupgrade
lint.select = [
select = [
"E4", "E7", "E9", "F", "I", "UP"
]
extend-exclude = [
"interfaces/*",
"monitoring/prober/output/*",
"monitoring/mock_uss/output/*",
"monitoring/uss_qualifier/output/*",
]
line-length = 88
# Explicitly ignore UP045 (Optional[Foo] -> Foo | None)
ignore = ["UP045"]
unfixable = ["UP045"]

[tool.basedpyright]
typeCheckingMode = "standard"
Expand Down
8 changes: 8 additions & 0 deletions schemas/monitoring/monitorlib/fetch/RequestDescription.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
"description": "Path to content that replaces the $ref",
"type": "string"
},
"auth_dt": {
"description": "Amount of time required to obtain authorization before performing the primary query (de minimus or unknown by default).",
"format": "duration",
"type": [
"string",
"null"
]
},
"body": {
"type": [
"string",
Expand Down
Loading