77import time
88import urllib .parse
99import weakref
10+ from dataclasses import dataclass
1011from enum import Enum
1112
1213import jwt
2829CLIENT_TIMEOUT = 10 # seconds
2930SOCKET_KEEP_ALIVE_LIMIT = 57 # seconds.
3031
32+ AUTHORIZATION_DT = "authorization_dt"
33+ """This attribute may be added to a PreparedRequest indicating the timedelta required to obtain authorization"""
34+
3135
3236AuthSpec = str
3337"""Specification for means by which to obtain access tokens."""
3438
3539
40+ @dataclass
41+ class AdditionalHeaders :
42+ headers : dict [str , str ]
43+ token_issuance_seconds : float | None = None
44+
45+
3646class AuthAdapter :
3747 """Base class for an adapter that add JWTs to requests."""
3848
@@ -44,33 +54,48 @@ def issue_token(self, intended_audience: str, scopes: list[str]) -> str:
4454
4555 raise NotImplementedError ()
4656
47- def get_headers (self , url : str , scopes : list [str ] | None = None ) -> dict [str , str ]:
57+ def get_headers (
58+ self , url : str , scopes : list [str ] | None = None
59+ ) -> AdditionalHeaders :
4860 if scopes is None :
4961 scopes = ALL_SCOPES
5062 scopes = [s .value if isinstance (s , Enum ) else s for s in scopes ]
5163 intended_audience = urllib .parse .urlparse (url ).hostname
5264
5365 if not intended_audience :
54- return {}
66+ return AdditionalHeaders ( headers = {})
5567
5668 scope_string = " " .join (scopes )
5769 if intended_audience not in self ._tokens :
5870 self ._tokens [intended_audience ] = {}
5971 if scope_string not in self ._tokens [intended_audience ]:
72+ t0 = time .monotonic ()
6073 token = self .issue_token (intended_audience , scopes )
74+ dt_s = time .monotonic () - t0
6175 else :
6276 token = self ._tokens [intended_audience ][scope_string ]
77+ dt_s = None
6378 payload = jwt .decode (token , options = {"verify_signature" : False })
6479 expires = EPOCH + datetime .timedelta (seconds = payload ["exp" ])
6580 if datetime .datetime .now (datetime .UTC ) > expires - TOKEN_REFRESH_MARGIN :
81+ t0 = time .monotonic ()
6682 token = self .issue_token (intended_audience , scopes )
83+ dt_s = (dt_s or 0 ) + (time .monotonic () - t0 )
6784 self ._tokens [intended_audience ][scope_string ] = token
68- return {"Authorization" : "Bearer " + token }
85+ return AdditionalHeaders (
86+ headers = {"Authorization" : "Bearer " + token }, token_issuance_seconds = dt_s
87+ )
6988
70- def add_headers (self , request : requests .PreparedRequest , scopes : list [str ]):
89+ def add_headers (
90+ self , request : requests .PreparedRequest , scopes : list [str ]
91+ ) -> AdditionalHeaders :
7192 if request .url :
72- for k , v in self .get_headers (request .url , scopes ).items ():
93+ additional_headers = self .get_headers (request .url , scopes )
94+ for k , v in additional_headers .headers .items ():
7395 request .headers [k ] = v
96+ return additional_headers
97+ else :
98+ return AdditionalHeaders (headers = {})
7499
75100 def get_sub (self ) -> str | None :
76101 """Retrieve `sub` claim from one of the existing tokens"""
@@ -182,6 +207,21 @@ def prepare_request(self, request, **kwargs):
182207
183208 return super ().prepare_request (request , ** kwargs )
184209
210+ def add_auth (
211+ self , prepared_request : requests .PreparedRequest , scopes : list [str ] | None
212+ ) -> requests .PreparedRequest :
213+ if scopes and self .auth_adapter :
214+ additional_headers = self .auth_adapter .add_headers (prepared_request , scopes )
215+ if additional_headers .token_issuance_seconds :
216+ setattr (
217+ prepared_request ,
218+ AUTHORIZATION_DT ,
219+ datetime .timedelta (
220+ seconds = additional_headers .token_issuance_seconds
221+ ),
222+ )
223+ return prepared_request
224+
185225 def adjust_request_kwargs (self , kwargs ):
186226 if self .auth_adapter :
187227 scopes = None
@@ -194,14 +234,7 @@ def adjust_request_kwargs(self, kwargs):
194234 if scopes is None :
195235 scopes = self .default_scopes
196236
197- def auth (
198- prepared_request : requests .PreparedRequest ,
199- ) -> requests .PreparedRequest :
200- if scopes and self .auth_adapter :
201- self .auth_adapter .add_headers (prepared_request , scopes )
202- return prepared_request
203-
204- kwargs ["auth" ] = auth
237+ kwargs ["auth" ] = lambda req : self .add_auth (req , scopes )
205238 if "timeout" not in kwargs :
206239 kwargs ["timeout" ] = self .timeout_seconds
207240 return kwargs
@@ -295,10 +328,8 @@ def adjust_request_kwargs(self, url, method, kwargs):
295328 raise ValueError (
296329 "All tests must specify auth scope for all session requests. Either specify as an argument for each individual HTTP call, or decorate the test with @default_scope."
297330 )
298- headers = {}
299- for k , v in self .auth_adapter .get_headers (url , scopes ).items ():
300- headers [k ] = v
301- kwargs ["headers" ] = headers
331+ additional_headers = self .auth_adapter .get_headers (url , scopes )
332+ kwargs ["headers" ] = additional_headers .headers
302333 if method == "PUT" and kwargs .get ("data" ):
303334 kwargs ["json" ] = kwargs ["data" ]
304335 del kwargs ["data" ]
0 commit comments