From 6f497a4b2877a5fb36109eff6718497246ff429f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Fri, 10 Jun 2022 14:05:04 +0200 Subject: [PATCH 1/6] Adding code flow auth and org switch * Added CodeFlowAuthTokenApiHandler to obtain a token via OAuth2 authorization code. * Added organization and organization_id to token and refresh_token requests. --- src/hiro_graph_client/__init__.py | 10 +- src/hiro_graph_client/clientlib.py | 484 +++++++++++++++++++++++------ 2 files changed, 396 insertions(+), 98 deletions(-) diff --git a/src/hiro_graph_client/__init__.py b/src/hiro_graph_client/__init__.py index 9e05f8d..f80ad3d 100644 --- a/src/hiro_graph_client/__init__.py +++ b/src/hiro_graph_client/__init__.py @@ -9,8 +9,8 @@ from hiro_graph_client.authzclient import HiroAuthz from hiro_graph_client.client import HiroGraph from hiro_graph_client.clientlib import AbstractTokenApiHandler, GraphConnectionHandler, AuthenticationTokenError, \ - FixedTokenError, TokenUnauthorizedError, PasswordAuthTokenApiHandler, FixedTokenApiHandler, \ - EnvironmentTokenApiHandler, SSLConfig + FixedTokenError, TokenUnauthorizedError, PasswordAuthTokenApiHandler, CodeFlowAuthTokenApiHandler, \ + FixedTokenApiHandler, EnvironmentTokenApiHandler, SSLConfig from hiro_graph_client.iamclient import HiroIam from hiro_graph_client.kiclient import HiroKi from hiro_graph_client.variablesclient import HiroVariables @@ -20,9 +20,9 @@ __all__ = [ 'HiroGraph', 'HiroAuth', 'HiroApp', 'HiroIam', 'HiroKi', 'HiroAuthz', 'HiroVariables', 'GraphConnectionHandler', - 'AbstractTokenApiHandler', 'PasswordAuthTokenApiHandler', 'FixedTokenApiHandler', 'EnvironmentTokenApiHandler', - 'AuthenticationTokenError', 'FixedTokenError', 'TokenUnauthorizedError', '__version__', - 'SSLConfig' + 'AbstractTokenApiHandler', 'PasswordAuthTokenApiHandler', 'CodeFlowAuthTokenApiHandler', 'FixedTokenApiHandler', + 'EnvironmentTokenApiHandler', 'AuthenticationTokenError', 'FixedTokenError', 'TokenUnauthorizedError', + '__version__', 'SSLConfig' ] site.addsitedir(this_directory) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index 96eac8c..d9aa02b 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -172,7 +172,6 @@ def __init__(self, max_tries = abstract_api._max_tries else: initial_headers = { - 'Content-Type': 'application/json', 'Accept': 'text/plain, application/json', 'User-Agent': f"{client_name or self._client_name} {__version__}" } @@ -215,21 +214,23 @@ def _capitalize_header(name: str) -> str: # Basic requests ############################################################################################################### - def get_binary(self, url: str, accept: str = None) -> Iterator[bytes]: + def get_binary(self, url: str, accept: str = None, headers: dict = None) -> Iterator[bytes]: """ Implementation of GET for binary data. :param url: Url to use :param accept: Mimetype for accept. Will be set to */* if not given. + :param headers: Optional additional headers. :return: Yields over raw chunks of the response payload. """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get_binary() -> Iterator[bytes]: + _headers: dict = {"Content-Type": None, "Accept": (accept or "*/*")} + if headers: + _headers.update(headers) with self._session.get(url, - headers=self._get_headers( - {"Content-Type": None, "Accept": (accept or "*/*")} - ), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -247,7 +248,8 @@ def post_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of POST for binary data. @@ -256,16 +258,18 @@ def post_binary(self, :param content_type: The content type of the data. Defaults to "application/octet-stream" internally if unset. :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post_binary() -> Any: + _headers: dict = {"Content-Type": (content_type or "application/octet-stream")} + if headers: + _headers.update(headers) res = self._session.post(url, data=data, - headers=self._get_headers( - {"Content-Type": (content_type or "application/octet-stream")} - ), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -279,7 +283,8 @@ def put_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of PUT for binary data. @@ -288,16 +293,18 @@ def put_binary(self, :param content_type: The content type of the data. Defaults to "application/octet-stream" internally if unset. :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put_binary() -> Any: + _headers: dict = {"Content-Type": (content_type or "application/octet-stream")} + if headers: + _headers.update(headers) res = self._session.put(url, data=data, - headers=self._get_headers( - {"Content-Type": (content_type or "application/octet-stream")} - ), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -335,7 +342,9 @@ def _get() -> Any: def post(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + content_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of POST @@ -343,18 +352,35 @@ def post(self, :param data: The payload to POST :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post() -> Any: - res = self._session.post(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers: dict = {"Content-Type": content_type} + if headers: + _headers.update(headers) + if content_type == 'application/json': + res = self._session.post(url, + json=data, + headers=self._get_headers(_headers), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.post(url, + data=data, + headers=self._get_headers(_headers), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + self._log_communication(res) return self._parse_response(res, expected_media_type) @@ -363,7 +389,9 @@ def _post() -> Any: def put(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + content_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of PUT @@ -371,18 +399,35 @@ def put(self, :param data: The payload to PUT :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put() -> Any: - res = self._session.put(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers: dict = {"Content-Type": content_type} + if headers: + _headers.update(headers) + if content_type == 'application/json': + res = self._session.put(url, + json=data, + headers=self._get_headers(), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.put(url, + data=data, + headers=self._get_headers({"Content-Type": content_type}), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + self._log_communication(res) return self._parse_response(res, expected_media_type) @@ -391,7 +436,9 @@ def _put() -> Any: def patch(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + content_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of PATCH @@ -399,18 +446,34 @@ def patch(self, :param data: The payload to PUT :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _patch() -> Any: - res = self._session.patch(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers: dict = {"Content-Type": content_type} + if headers: + _headers.update(headers) + if content_type == 'application/json': + res = self._session.patch(url, + json=data, + headers=self._get_headers(_headers), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.patch(url, + data=data, + headers=self._get_headers(_headers), + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) self._log_communication(res) return self._parse_response(res, expected_media_type) @@ -725,7 +788,7 @@ def __init__(self, version_info: dict = None, pool_maxsize: int = None, pool_block: bool = None, - connection_handler=None, + connection_handler = None, *args, **kwargs): """ @@ -966,9 +1029,15 @@ def decode_token_ext(token: str) -> dict: return dict(json.loads(json_payload)) @abstractmethod - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: """ Refresh the current token. + + Organization information does not overwrite the internal organization or organization_id. + + :param organization: Optional name of an organization to be used for the token. + :param organization_id: Optional id of an organization to be used for the token. Overrides + parameter *organization*. """ raise RuntimeError('Cannot use method of this abstract class.') @@ -1019,7 +1088,7 @@ def __init__(self, token: str = None, *args, **kwargs): def token(self) -> str: return self._token - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: raise FixedTokenError('Token is invalid and cannot be changed because it has been given externally.') def revoke_token(self, token_hint: str = "revoke_token") -> None: @@ -1060,7 +1129,7 @@ def __init__(self, env_var: str = 'HIRO_TOKEN', *args, **kwargs): def token(self) -> str: return os.environ[self._env_var] - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: raise FixedTokenError( "Token is invalid and cannot be changed because it has been given as environment variable '{}'" " externally.".format(self._env_var)) @@ -1187,10 +1256,9 @@ def clear_token_data(self, access_token_only: bool): self.expires_at = -1 -class PasswordAuthTokenApiHandler(AbstractTokenApiHandler): +class AbstractRemoteTokenApiHandler(AbstractTokenApiHandler): """ - API Tokens will be fetched using this class. It does not handle any automatic token fetching, refresh or token - expiry. This has to be checked and triggered by the *caller*. + Remote API Tokens will be fetched using this class. The methods of this class are thread-safe, so it can be shared between several HIRO objects. @@ -1203,19 +1271,20 @@ class PasswordAuthTokenApiHandler(AbstractTokenApiHandler): _lock: threading.RLock """Reentrant mutex for thread safety""" - _username: str - _password: str _client_id: str _client_secret: str + _organization: str + _organization_id: str + _secure_logging: bool = True """Avoid logging of sensitive data.""" def __init__(self, - username: str = None, - password: str = None, client_id: str = None, client_secret: str = None, + organization: str = None, + organization_id: str = None, secure_logging: bool = True, *args, **kwargs): """ @@ -1224,20 +1293,21 @@ def __init__(self, See parent :class:`AbstractTokenApiHandler` for a full description of all remaining parameters. - :param username: Username for authentication - :param password: Password for authentication :param client_id: OAuth client_id for authentication - :param client_secret: OAuth client_secret for authentication + :param client_secret: OAuth client_secret for authentication (This can be None in special cases). + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. - :param args: Unnamed parameter passthrough for parent class. - :param kwargs: Named parameter passthrough for parent class. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. """ super().__init__(*args, **kwargs) - self._username = username - self._password = password self._client_id = client_id self._client_secret = client_secret + self._organization = organization + self._organization_id = organization_id self._secure_logging = secure_logging @@ -1274,44 +1344,29 @@ def _log_communication(self, res: requests.Response, request_body: bool = True, super()._log_communication(res, request_body=log_request_body, response_body=log_response_body) - def get_token(self) -> None: + @abstractmethod + def get_token(self, organization: str = None, organization_id: str = None) -> None: """ - Construct a request to obtain a new token. API self._endpoint + '/app' + Construct a request to obtain a new token. + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. :raises AuthenticationTokenError: When no auth_endpoint is set. """ - with self._lock: - if not self.endpoint: - raise AuthenticationTokenError( - 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') - - if not self._username or not self._password or not self._client_id or not self._client_secret: - msg = "" - if not self._username: - msg += "'username'" - if not self._password: - msg += (", " if msg else "") + "'password'" - if not self._client_id: - msg += (", " if msg else "") + "'client_id'" - if not self._client_secret: - msg += (", " if msg else "") + "'client_secret'" - raise AuthenticationTokenError( - "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + raise RuntimeError('Cannot use method of this abstract class.') - url = self.endpoint + '/app' - data = { - "client_id": self._client_id, - "client_secret": self._client_secret, - "username": self._username, - "password": self._password - } + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to refresh an existing token. - res = self.post(url, data) - self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + API self._endpoint + '/refresh'. (until /api/auth/6.5) + API self._endpoint + '/token'. (since /api/auth/6.6) - def refresh_token(self) -> None: - """ - Construct a request to refresh an existing token. API self._endpoint + '/refresh'. + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. :raises AuthenticationTokenError: When no auth_endpoint is set. """ @@ -1321,21 +1376,38 @@ def refresh_token(self) -> None: 'Token is invalid and endpoint (auth_endpoint) for refresh is not set.') if not self._token_info.refresh_token: - self.get_token() + self.get_token(organization=organization, organization_id=organization_id) return - url = self.endpoint + '/refresh' - data = { + auth_api_version = float(self._version_info['auth']['version']) + + if auth_api_version >= 6.6: + url = self.endpoint + '/token' + content_type = 'application/x-www-form-urlencoded' + data = { + "grant_type": "refresh_token" + } + else: + url = self.endpoint + '/refresh' + content_type = 'application/json' + data = {} + + self._organization = organization + self._organization_id = organization_id + + data.update({ "client_id": self._client_id, "client_secret": self._client_secret, - "refresh_token": self._token_info.refresh_token - } + "refresh_token": self._token_info.refresh_token, + "organization": self._organization, + "organization_id": self._organization_id + }) try: - res = self.post(url, data) + res = self.post(url, data, content_type=content_type) self._token_info.parse_token_result(res, "{}.refresh_token".format(self.__class__.__name__)) except AuthenticationTokenError: - self.get_token() + self.get_token(organization=organization, organization_id=organization_id) def revoke_token(self, token_hint: str = "refresh_token") -> None: """ @@ -1413,6 +1485,232 @@ def _handle_token(self) -> Optional[str]: return None +class PasswordAuthTokenApiHandler(AbstractRemoteTokenApiHandler): + """ + Implements the OAuth2 password auth flow. + + The methods of this class are thread-safe, so it can be shared between several HIRO objects. + + It is built this way to avoid endless calling loops when resolving tokens. + """ + _username: str + _password: str + + def __init__(self, + username: str = None, + password: str = None, + connection_handler=None, + *args, **kwargs): + """ + Constructor + + See parent :class:`AbstractRemoteTokenApiHandler` for a full description + of all remaining parameters. + + :param username: Username for authentication + :param password: Password for authentication + :param client_id: OAuth client_id for authentication + :param client_secret: OAuth client_secret for authentication + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. + :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. + """ + super().__init__(connection_handler=connection_handler, *args, **kwargs) + + self._username = username + self._password = password + + def get_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to obtain a new token. This is the OAuth2 password auth flow. + + API self._endpoint + '/app' (until /api/auth/6.5) + API self._endpoint + '/token'. (since /api/auth/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. + :raises AuthenticationTokenError: When no auth_endpoint is set. + """ + with self._lock: + if not self.endpoint: + raise AuthenticationTokenError( + 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') + + if not self._username or not self._password or not self._client_id or not self._client_secret: + msg = "" + if not self._username: + msg += "'username'" + if not self._password: + msg += (", " if msg else "") + "'password'" + if not self._client_id: + msg += (", " if msg else "") + "'client_id'" + if not self._client_secret: + msg += (", " if msg else "") + "'client_secret'" + raise AuthenticationTokenError( + "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + + auth_api_version = float(self._version_info['auth']['version']) + + if auth_api_version >= 6.6: + url = self.endpoint + '/token' + content_type = 'application/x-www-form-urlencoded' + data = { + "grant_type": "password" + } + else: + url = self.endpoint + '/app' + content_type = 'application/json' + data = {} + + self._organization = organization + self._organization_id = organization_id + + data.update({ + "client_id": self._client_id, + "client_secret": self._client_secret, + "username": self._username, + "password": self._password, + "organization": self._organization, + "organization_id": self._organization_id + }) + + res = self.post(url, data, content_type=content_type) + self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + + +class CodeFlowAuthTokenApiHandler(AbstractRemoteTokenApiHandler): + """ + Implements the OAuth2 authorization_code auth flow. This is only the second step where you already + obtained the code parameter from the authorization server. + + Be aware, that obtaining a token can only be used once with the same value of *self._code*. + + The methods of this class are thread-safe, so it can be shared between several HIRO objects. + + It is built this way to avoid endless calling loops when resolving tokens. + """ + _code: str + _code_verifier: str + _redirect_uri: str + + def __init__(self, + code: str = None, + redirect_uri: str = None, + code_verifier: str = None, + connection_handler=None, + *args, **kwargs): + """ + Constructor + + Be aware, that obtaining a token via authentication_code flow can only be done once with the same value of + *self._code*. If a token refresh fails, you need to start a new authorization code flow. + + See parent :class:`AbstractRemoteTokenApiHandler` for a full description + of all remaining parameters. + + :param code: One time code received from the authorization server. + :param redirect_uri: The original redirect_uri parameter from the authorization-redirect call (first call of the + code flow not handled here). + :param code_verifier: The code_verifier for the PKCE code flow. + :param client_id: OAuth client_id for authentication + :param client_secret: OAuth client_secret for authentication. This is optional here. + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. + :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. + """ + super().__init__(connection_handler=connection_handler, *args, **kwargs) + + self._code = code + self._code_verifier = code_verifier + self._redirect_uri = redirect_uri + + def get_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to obtain a new token. This is the second step of the authorization_code flow. + + Be aware, that this method can only be used once with the same value of *self._code*. If a token has been + obtained here, this method cannot be re-used. If a token refresh fails, you need to start a new authorization + code flow. + + API self._endpoint + '/token' (since /auth/api/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. + :raises AuthenticationTokenError: When no auth_endpoint is set or /api/auth/${version} is below 6.6. + """ + with self._lock: + if not self.endpoint: + raise AuthenticationTokenError( + 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') + + if not self._code or not self._code_verifier or not self._client_id: + msg = "" + if not self._code: + msg += "'code'" + if not self._client_id: + msg += (", " if msg else "") + "'client_id'" + if not self._redirect_uri: + msg += (", " if msg else "") + "'redirect_uri'" + raise AuthenticationTokenError( + "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + + auth_api_version = float(self._version_info['auth']['version']) + if auth_api_version < 6.6: + raise AuthenticationTokenError("Auth api version /api/auth/[version] has to be at least 6.6.") + + url = self.endpoint + '/token' + + # If a client_secret is present, set "client_id" and "client_secret", if not, only add "clientId" as + # form param (peculiar of WSO2). + if self._client_secret: + data = { + "client_id": self._client_id, + "client_secret": self._client_secret + } + else: + data = { + "clientId": self._client_id + } + + self._organization = organization + self._organization_id = organization_id + + data.update({ + "grant_type": "authorization_code", + "code": self._code, + "code_verifier": self._code_verifier, + "redirect_uri": self._redirect_uri, + "organization": self._organization, + "organization_id": self._organization_id + }) + + res = self.post(url, data, content_type='application/x-www-form-urlencoded') + self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to refresh an existing token. This fails immediately if no refresh_token is available. + + API self._endpoint + '/token'. (since /api/auth/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. + + :raises AuthenticationTokenError: When no auth_endpoint or refresh_token is set. + """ + if not self._token_info.refresh_token: + raise AuthenticationTokenError("Token cannot be refreshed without a refresh_token.") + + super().refresh_token(organization, organization_id) + + ################################################################################################################### # Root class for different API groups ################################################################################################################### From f6d486fa689a0c51b121781510586f79a9618b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Fri, 10 Jun 2022 15:38:55 +0200 Subject: [PATCH 2/6] Debugged header handling * Added header handling to all HTTP calls. Make sure, that parameter content_type and keys in header dict do not collide. --- src/hiro_graph_client/clientlib.py | 103 +++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 28 deletions(-) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index d9aa02b..3fe259d 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -210,6 +210,34 @@ def user_agent(self): def _capitalize_header(name: str) -> str: return "-".join([n.capitalize() for n in name.split('-')]) + @staticmethod + def _handle_content_type_and_headers(content_type: Optional[str], + extern_headers: dict = None, + initial_headers: dict = None) -> dict: + """ + Handle header merging and content_type parameter in basic requests below. + + :param content_type: The desired content_type for the call. This can explicitly set to None to erase it from the + headers in the final http call. This value will always be present under key 'Content-Type' in the + returned dict and overwrites every other specification in the header dicts unless set to None. + :param extern_headers: Optional external headers that need to me merged with the initial headers. extern_headers + will be merged upon initial_headers, overwriting values within it. + :param initial_headers: Headers which are specific for the type of http call. + :return: A dict with headers. Will always contain a key "Content-Type" - the value of which might be None. + """ + final_headers: dict = initial_headers or {} + if extern_headers: + final_headers.update(extern_headers) + + for name in final_headers: + if name.lower() == "content-type": + content_type = content_type or final_headers[name] + del final_headers[name] + + final_headers.update({"Content-Type": content_type}) + + return final_headers + ############################################################################################################### # Basic requests ############################################################################################################### @@ -226,9 +254,11 @@ def get_binary(self, url: str, accept: str = None, headers: dict = None) -> Iter @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get_binary() -> Iterator[bytes]: - _headers: dict = {"Content-Type": None, "Accept": (accept or "*/*")} - if headers: - _headers.update(headers) + _headers = AbstractAPI._handle_content_type_and_headers( + None, + headers, + {"Accept": (accept or "*/*")} + ) with self._session.get(url, headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), @@ -264,9 +294,10 @@ def post_binary(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post_binary() -> Any: - _headers: dict = {"Content-Type": (content_type or "application/octet-stream")} - if headers: - _headers.update(headers) + _headers = AbstractAPI._handle_content_type_and_headers( + content_type or "application/octet-stream", + headers + ) res = self._session.post(url, data=data, headers=self._get_headers(_headers), @@ -299,9 +330,10 @@ def put_binary(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put_binary() -> Any: - _headers: dict = {"Content-Type": (content_type or "application/octet-stream")} - if headers: - _headers.update(headers) + _headers = AbstractAPI._handle_content_type_and_headers( + content_type or "application/octet-stream", + headers + ) res = self._session.put(url, data=data, headers=self._get_headers(_headers), @@ -316,20 +348,26 @@ def _put_binary() -> Any: def get(self, url: str, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of GET :param url: Url to use :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get() -> Any: + _headers = AbstractAPI._handle_content_type_and_headers( + None, + headers + ) res = self._session.get(url, - headers=self._get_headers({"Content-Type": None}), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -361,10 +399,11 @@ def post(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post() -> Any: - _headers: dict = {"Content-Type": content_type} - if headers: - _headers.update(headers) - if content_type == 'application/json': + _headers = AbstractAPI._handle_content_type_and_headers( + content_type, + headers + ) + if _headers["Content-Type"].startswith('application/json'): res = self._session.post(url, json=data, headers=self._get_headers(_headers), @@ -408,13 +447,14 @@ def put(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put() -> Any: - _headers: dict = {"Content-Type": content_type} - if headers: - _headers.update(headers) - if content_type == 'application/json': + _headers = AbstractAPI._handle_content_type_and_headers( + content_type, + headers + ) + if _headers["Content-Type"].startswith('application/json'): res = self._session.put(url, json=data, - headers=self._get_headers(), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -422,7 +462,7 @@ def _put() -> Any: else: res = self._session.put(url, data=data, - headers=self._get_headers({"Content-Type": content_type}), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -455,10 +495,11 @@ def patch(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _patch() -> Any: - _headers: dict = {"Content-Type": content_type} - if headers: - _headers.update(headers) - if content_type == 'application/json': + _headers = AbstractAPI._handle_content_type_and_headers( + content_type, + headers + ) + if _headers["Content-Type"].startswith('application/json'): res = self._session.patch(url, json=data, headers=self._get_headers(_headers), @@ -481,20 +522,26 @@ def _patch() -> Any: def delete(self, url: str, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = 'application/json', + headers: dict = None) -> Any: """ Implementation of DELETE :param url: Url to use :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _delete() -> Any: + _headers = AbstractAPI._handle_content_type_and_headers( + None, + headers + ) res = self._session.delete(url, - headers=self._get_headers({"Content-Type": None}), + headers=self._get_headers(_headers), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -788,7 +835,7 @@ def __init__(self, version_info: dict = None, pool_maxsize: int = None, pool_block: bool = None, - connection_handler = None, + connection_handler=None, *args, **kwargs): """ From cd00886bfa73e0436ad8dc7c48df7a6162862fc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Fri, 10 Jun 2022 16:12:30 +0200 Subject: [PATCH 3/6] Refactorings and cleanup --- src/hiro_graph_client/clientlib.py | 117 +++++++++++------------------ 1 file changed, 44 insertions(+), 73 deletions(-) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index 3fe259d..f0be706 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -210,34 +210,6 @@ def user_agent(self): def _capitalize_header(name: str) -> str: return "-".join([n.capitalize() for n in name.split('-')]) - @staticmethod - def _handle_content_type_and_headers(content_type: Optional[str], - extern_headers: dict = None, - initial_headers: dict = None) -> dict: - """ - Handle header merging and content_type parameter in basic requests below. - - :param content_type: The desired content_type for the call. This can explicitly set to None to erase it from the - headers in the final http call. This value will always be present under key 'Content-Type' in the - returned dict and overwrites every other specification in the header dicts unless set to None. - :param extern_headers: Optional external headers that need to me merged with the initial headers. extern_headers - will be merged upon initial_headers, overwriting values within it. - :param initial_headers: Headers which are specific for the type of http call. - :return: A dict with headers. Will always contain a key "Content-Type" - the value of which might be None. - """ - final_headers: dict = initial_headers or {} - if extern_headers: - final_headers.update(extern_headers) - - for name in final_headers: - if name.lower() == "content-type": - content_type = content_type or final_headers[name] - del final_headers[name] - - final_headers.update({"Content-Type": content_type}) - - return final_headers - ############################################################################################################### # Basic requests ############################################################################################################### @@ -254,13 +226,11 @@ def get_binary(self, url: str, accept: str = None, headers: dict = None) -> Iter @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get_binary() -> Iterator[bytes]: - _headers = AbstractAPI._handle_content_type_and_headers( - None, - headers, - {"Accept": (accept or "*/*")} - ) + _headers = self._get_headers(None, headers) + _headers.update({"Accept": (accept or "*/*")}) + with self._session.get(url, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -294,13 +264,12 @@ def post_binary(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post_binary() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( - content_type or "application/octet-stream", - headers - ) res = self._session.post(url, data=data, - headers=self._get_headers(_headers), + headers=self._get_headers( + content_type or "application/octet-stream", + headers + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -330,13 +299,12 @@ def put_binary(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put_binary() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( - content_type or "application/octet-stream", - headers - ) res = self._session.put(url, data=data, - headers=self._get_headers(_headers), + headers=self._get_headers( + content_type or "application/octet-stream", + headers + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -362,12 +330,11 @@ def get(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( - None, - headers - ) res = self._session.get(url, - headers=self._get_headers(_headers), + headers=self._get_headers( + None, + headers + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -399,14 +366,14 @@ def post(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( + _headers = self._get_headers( content_type, headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.post(url, json=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -414,7 +381,7 @@ def _post() -> Any: else: res = self._session.post(url, data=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -447,14 +414,14 @@ def put(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( + _headers = self._get_headers( content_type, headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.put(url, json=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -462,7 +429,7 @@ def _put() -> Any: else: res = self._session.put(url, data=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -495,14 +462,14 @@ def patch(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _patch() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( + _headers = self._get_headers( content_type, headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.patch(url, json=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -510,7 +477,7 @@ def _patch() -> Any: else: res = self._session.patch(url, data=data, - headers=self._get_headers(_headers), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -536,12 +503,11 @@ def delete(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _delete() -> Any: - _headers = AbstractAPI._handle_content_type_and_headers( - None, - headers - ) res = self._session.delete(url, - headers=self._get_headers(_headers), + headers=self._get_headers( + None, + headers + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -579,17 +545,24 @@ def _merge_headers(headers: dict, override: dict) -> dict: return headers - def _get_headers(self, override: dict = None) -> dict: + def _get_headers(self, + content_type: Optional[str], + override: dict = None) -> dict: """ - Create a header dict for requests. Uses abstract method *self._handle_token()*. + Handle header merging and content_type parameter in basic requests below. - :param override: Dict of headers that override the internal headers. If a header key is set to value None, - it will be removed from the headers. - :return: A dict containing header values for requests. + :param content_type: The desired content_type for the call. This can explicitly set to None to erase it from the + headers in the final http call. This value will always be present under key 'Content-Type' in the + returned dict and overwrites every other specification in the header dicts unless set to None. + :param override: Optional external headers, overriding the initial headers. + :return: A dict with headers. Will always contain a key "Content-Type" - the value of which might be None. Will + also contain Authorization Bearer Token if available. """ - headers = AbstractAPI._merge_headers(self._headers.copy(), override) + if content_type: + headers.update({"Content-Type": content_type}) + token = self._handle_token() if token: headers['Authorization'] = "Bearer " + token @@ -1546,7 +1519,6 @@ class PasswordAuthTokenApiHandler(AbstractRemoteTokenApiHandler): def __init__(self, username: str = None, password: str = None, - connection_handler=None, *args, **kwargs): """ Constructor @@ -1565,7 +1537,7 @@ def __init__(self, :param args: Unnamed parameter passthrough for parent class. :param kwargs: Named parameter passthrough for parent class. """ - super().__init__(connection_handler=connection_handler, *args, **kwargs) + super().__init__(*args, **kwargs) self._username = username self._password = password @@ -1647,7 +1619,6 @@ def __init__(self, code: str = None, redirect_uri: str = None, code_verifier: str = None, - connection_handler=None, *args, **kwargs): """ Constructor @@ -1671,7 +1642,7 @@ def __init__(self, :param args: Unnamed parameter passthrough for parent class. :param kwargs: Named parameter passthrough for parent class. """ - super().__init__(connection_handler=connection_handler, *args, **kwargs) + super().__init__(*args, **kwargs) self._code = code self._code_verifier = code_verifier From 429b7c6dda04da9146791f258008eec74c113fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Mon, 13 Jun 2022 11:51:18 +0200 Subject: [PATCH 4/6] Fixing content-type handling --- src/hiro_graph_client/clientlib.py | 90 +++++++++++++++++------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index f0be706..a94c512 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -214,7 +214,10 @@ def _capitalize_header(name: str) -> str: # Basic requests ############################################################################################################### - def get_binary(self, url: str, accept: str = None, headers: dict = None) -> Iterator[bytes]: + def get_binary(self, + url: str, + accept: str = None, + headers: dict = None) -> Iterator[bytes]: """ Implementation of GET for binary data. @@ -226,7 +229,10 @@ def get_binary(self, url: str, accept: str = None, headers: dict = None) -> Iter @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get_binary() -> Iterator[bytes]: - _headers = self._get_headers(None, headers) + _headers = self._get_headers( + content_type=None, + override=headers, + remove_content_type=True) _headers.update({"Accept": (accept or "*/*")}) with self._session.get(url, @@ -248,7 +254,7 @@ def post_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json', + expected_media_type: str = None, headers: dict = None) -> Any: """ Implementation of POST for binary data. @@ -267,15 +273,15 @@ def _post_binary() -> Any: res = self._session.post(url, data=data, headers=self._get_headers( - content_type or "application/octet-stream", - headers + content_type=content_type or "application/octet-stream", + override=headers ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res, request_body=False) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _post_binary() @@ -283,7 +289,7 @@ def put_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json', + expected_media_type: str = None, headers: dict = None) -> Any: """ Implementation of PUT for binary data. @@ -302,21 +308,21 @@ def _put_binary() -> Any: res = self._session.put(url, data=data, headers=self._get_headers( - content_type or "application/octet-stream", - headers + content_type=content_type or "application/octet-stream", + override=headers ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res, request_body=False) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _put_binary() def get(self, url: str, - expected_media_type: str = 'application/json', + expected_media_type: str = None, headers: dict = None) -> Any: """ Implementation of GET @@ -332,23 +338,24 @@ def get(self, def _get() -> Any: res = self._session.get(url, headers=self._get_headers( - None, - headers + content_type=None, + override=headers, + remove_content_type=True ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _get() def post(self, url: str, data: Any, - expected_media_type: str = 'application/json', - content_type: str = 'application/json', + expected_media_type: str = None, + content_type: str = None, headers: dict = None) -> Any: """ Implementation of POST @@ -367,8 +374,8 @@ def post(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post() -> Any: _headers = self._get_headers( - content_type, - headers + content_type=content_type or 'application/json', + override=headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.post(url, @@ -388,15 +395,15 @@ def _post() -> Any: proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _post() def put(self, url: str, data: Any, - expected_media_type: str = 'application/json', - content_type: str = 'application/json', + expected_media_type: str = None, + content_type: str = None, headers: dict = None) -> Any: """ Implementation of PUT @@ -415,8 +422,8 @@ def put(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put() -> Any: _headers = self._get_headers( - content_type, - headers + content_type=content_type or 'application/json', + override=headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.put(url, @@ -436,15 +443,15 @@ def _put() -> Any: proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _put() def patch(self, url: str, data: Any, - expected_media_type: str = 'application/json', - content_type: str = 'application/json', + expected_media_type: str = None, + content_type: str = None, headers: dict = None) -> Any: """ Implementation of PATCH @@ -463,8 +470,8 @@ def patch(self, @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _patch() -> Any: _headers = self._get_headers( - content_type, - headers + content_type=content_type or 'application/json', + override=headers ) if _headers["Content-Type"].startswith('application/json'): res = self._session.patch(url, @@ -483,13 +490,13 @@ def _patch() -> Any: timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _patch() def delete(self, url: str, - expected_media_type: str = 'application/json', + expected_media_type: str = None, headers: dict = None) -> Any: """ Implementation of DELETE @@ -505,15 +512,16 @@ def delete(self, def _delete() -> Any: res = self._session.delete(url, headers=self._get_headers( - None, - headers + content_type=None, + override=headers, + remove_content_type=True ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _delete() @@ -547,14 +555,16 @@ def _merge_headers(headers: dict, override: dict) -> dict: def _get_headers(self, content_type: Optional[str], - override: dict = None) -> dict: + override: dict = None, + remove_content_type: bool = None) -> dict: """ Handle header merging and content_type parameter in basic requests below. - :param content_type: The desired content_type for the call. This can explicitly set to None to erase it from the - headers in the final http call. This value will always be present under key 'Content-Type' in the - returned dict and overwrites every other specification in the header dicts unless set to None. + :param content_type: The desired content_type for the call. This value will be present under key + 'Content-Type' in the returned dict and overwrites every other specification in the header dicts unless + set to None. :param override: Optional external headers, overriding the initial headers. + :param remove_content_type: Explicitly remove the key "Content-Type" from the returned header map. :return: A dict with headers. Will always contain a key "Content-Type" - the value of which might be None. Will also contain Authorization Bearer Token if available. """ @@ -562,6 +572,8 @@ def _get_headers(self, if content_type: headers.update({"Content-Type": content_type}) + if remove_content_type: + headers.pop("Content-Type") token = self._handle_token() if token: @@ -596,7 +608,7 @@ def _get_query_part(params: dict) -> str: def _parse_response(self, res: requests.Response, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None) -> Any: """ Parse the response of the backend. @@ -612,6 +624,8 @@ def _parse_response(self, try: self._check_response(res) self._check_status_error(res) + if not expected_media_type: + expected_media_type = 'application/json' if expected_media_type not in ['*', '*/*']: AbstractAPI._check_content_type(res, expected_media_type) if expected_media_type.lower() == 'application/json': From c44daf1e34d48990c654fb9a32277499e103cccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Fri, 17 Jun 2022 14:50:52 +0200 Subject: [PATCH 5/6] Generate authorization url within --- src/hiro_graph_client/clientlib.py | 64 ++++++++++++++++++++------ src/hiro_graph_client/requirements.txt | 3 +- src/setup.py | 3 +- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index a94c512..4ed6878 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -3,6 +3,7 @@ import json import logging import os +import secrets import threading import time import urllib @@ -11,6 +12,7 @@ from urllib.parse import quote, urlencode import backoff +import pkce import requests import requests.adapters @@ -598,12 +600,12 @@ def _bool_to_external_str(value: Any) -> Optional[str]: @staticmethod def _get_query_part(params: dict) -> str: """ - Create the query part of an url. Keys in *params* whose values are set to None are removed. + Create the query part of an url. Keys in *params* whose values are set to None or empty are removed. :param params: A dict of params to use for the query. :return: The query part of an url with a leading '?', or an empty string when query is empty. """ - params_cleaned = {k: AbstractAPI._bool_to_external_str(v) for k, v in params.items() if v is not None} + params_cleaned = {k: AbstractAPI._bool_to_external_str(v) for k, v in params.items() if v} return ('?' + urlencode(params_cleaned, quote_via=quote, safe="/,")) if params_cleaned else "" def _parse_response(self, @@ -1625,14 +1627,15 @@ class CodeFlowAuthTokenApiHandler(AbstractRemoteTokenApiHandler): It is built this way to avoid endless calling loops when resolving tokens. """ - _code: str - _code_verifier: str + _scope: str _redirect_uri: str + _code_verifier: str + _code: str + _state: str def __init__(self, - code: str = None, - redirect_uri: str = None, - code_verifier: str = None, + redirect_uri: str, + scope: str = None, *args, **kwargs): """ Constructor @@ -1643,10 +1646,8 @@ def __init__(self, See parent :class:`AbstractRemoteTokenApiHandler` for a full description of all remaining parameters. - :param code: One time code received from the authorization server. - :param redirect_uri: The original redirect_uri parameter from the authorization-redirect call (first call of the - code flow not handled here). - :param code_verifier: The code_verifier for the PKCE code flow. + :param redirect_uri: The redirect_uri parameter for the authorization-redirect call. + :param scope: Optional scope for OAuth login. :param client_id: OAuth client_id for authentication :param client_secret: OAuth client_secret for authentication. This is optional here. :param organization: Optional name of an organization to be used for the token requests. @@ -1658,10 +1659,47 @@ def __init__(self, """ super().__init__(*args, **kwargs) - self._code = code - self._code_verifier = code_verifier + self._scope = scope self._redirect_uri = redirect_uri + self._state = secrets.token_urlsafe(16) + self._code_verifier = pkce.generate_code_verifier(length=64) + + def get_authorize_uri(self) -> str: + """ + Construct an authorization uri for your browser. + + :return: The uri for the browser. + """ + if not self._redirect_uri: + raise AuthenticationTokenError("redirect_uri is missing", 400) + + url = self.endpoint + "/authorize" + + data = { + "response_type": "code", + "client_id": self._client_id, + "redirect_uri": self._redirect_uri, + "code_challenge": pkce.get_code_challenge(self._code_verifier), + "code_challenge_method": "S256", + "state": self._state, + "scope": self._scope + } + + return url + self._get_query_part(data) + + def handle_authorize_callback(self, state: str, code: str) -> None: + """ + Saves the code and matches the state from the authorization callback. + + :param state: The state returned by the callback. Must not have changed. + :param code: The one-time-code returned (used to obtain a full authorization token). + """ + if state != self._state: + raise AuthenticationTokenError("The parameter 'state' of the callback does not match.", 400) + + self._code = code + def get_token(self, organization: str = None, organization_id: str = None) -> None: """ Construct a request to obtain a new token. This is the second step of the authorization_code flow. diff --git a/src/hiro_graph_client/requirements.txt b/src/hiro_graph_client/requirements.txt index 38a1c6a..d1573b1 100644 --- a/src/hiro_graph_client/requirements.txt +++ b/src/hiro_graph_client/requirements.txt @@ -4,4 +4,5 @@ backoff==2.0.1 setuptools==62.1.0 websocket-client==1.3.2 -apscheduler==3.9.1 \ No newline at end of file +apscheduler==3.9.1 +pkce==1.0.3 \ No newline at end of file diff --git a/src/setup.py b/src/setup.py index e95ce12..6c75d7e 100644 --- a/src/setup.py +++ b/src/setup.py @@ -31,7 +31,8 @@ 'requests', 'backoff', 'websocket-client', - 'apscheduler' + 'apscheduler', + 'pkce' ], package_data={ name: ['VERSION'] From f9bb91fcbe45c23702d0faf47e3dcaa611d3810d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wolfgang=20H=C3=BCbner?= Date: Mon, 20 Jun 2022 08:39:54 +0200 Subject: [PATCH 6/6] Initialize state and code_verifier anew * Each call to get_authorize_url shall use a new state and code_verifier. --- src/hiro_graph_client/clientlib.py | 7 +- tests/unit/test_client.py | 113 +++++++++++++++++++++++------ 2 files changed, 94 insertions(+), 26 deletions(-) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index 4ed6878..2f0adf0 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -1662,9 +1662,6 @@ def __init__(self, self._scope = scope self._redirect_uri = redirect_uri - self._state = secrets.token_urlsafe(16) - self._code_verifier = pkce.generate_code_verifier(length=64) - def get_authorize_uri(self) -> str: """ Construct an authorization uri for your browser. @@ -1676,6 +1673,10 @@ def get_authorize_uri(self) -> str: url = self.endpoint + "/authorize" + # Initialize state and code_verifier anew with each call. + self._state = secrets.token_urlsafe(16) + self._code_verifier = pkce.generate_code_verifier(length=64) + data = { "response_type": "code", "client_id": self._client_id, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ca7a1ee..633d507 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,40 +1,106 @@ -from hiro_graph_client import PasswordAuthTokenApiHandler, HiroGraph, SSLConfig, GraphConnectionHandler -from .testconfig import CONFIG +import base64 +import gzip +import json + +from hiro_graph_client import CodeFlowAuthTokenApiHandler, PasswordAuthTokenApiHandler, SSLConfig, \ + GraphConnectionHandler, HiroApp, HiroAuth +from .testconfig import CONFIG_STAGE as CONFIG class TestClient: + connection_handler = GraphConnectionHandler( root_url=CONFIG.get('URL'), ssl_config=SSLConfig(verify=False) ) - hiro_api_handler = PasswordAuthTokenApiHandler( - username=CONFIG.get('USERNAME'), - password=CONFIG.get('PASSWORD'), - client_id=CONFIG.get('CLIENT_ID'), - client_secret=CONFIG.get('CLIENT_SECRET'), - secure_logging=False, - connection_handler=connection_handler - ) + def test_download_config(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) - hiro_api_handler2 = PasswordAuthTokenApiHandler( - username=CONFIG.get('USERNAME'), - password=CONFIG.get('PASSWORD'), - client_id=CONFIG.get('CLIENT_ID'), - client_secret=CONFIG.get('CLIENT_SECRET'), - secure_logging=False, - connection_handler=connection_handler - ) + hiro_app: HiroApp = HiroApp(api_handler=hiro_api_handler) - def test_simple_query(self): - hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler) + result = hiro_app.get_config() + + content = result.get("content") + content_type = result.get("type") + + content_bytes = base64.b64decode(content) if "base64" in content_type else bytearray(content, encoding='utf8') + + config_data = str(gzip.decompress(content_bytes), encoding='utf8') if "gz" in content_type else str( + content_bytes, encoding='utf8') + + with open("connector.conf", "w") as file: + print(config_data, file=file) + + def test_connect(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) + + hiro_auth: HiroAuth = HiroAuth(api_handler=hiro_api_handler) + + print(json.dumps(hiro_auth.get_identity(), indent=2)) - hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + def test_refresh_token(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) - hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler2) + handler = hiro_api_handler - hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + handler.get_token() + handler.refresh_token() + hiro_auth: HiroAuth = HiroAuth(api_handler=handler) + + print(json.dumps(hiro_auth.get_identity(), indent=2)) + + handler.revoke_token() + + print(json.dumps(hiro_auth.get_identity(), indent=2)) + + def test_auth_code_flow(self): + hiro_api_handler = CodeFlowAuthTokenApiHandler( + client_id=CONFIG.get('CLIENT_ID'), + code='4b12c5f1-cdd7-3487-a5b0-82305b011b67', + code_verifier='BQ3kyADe6RKYHttFwGPwvLrX_B6zwCr2vNRe00TQLfoRo-HYhUqM8sPKUQlkhbwUQxli2ZneFhFqx4xbltI4WQ', + redirect_uri='http://wksw-whuebner.arago.de:8080/oauth2/app/callback.xhtml', + secure_logging=False, + connection_handler=self.connection_handler + ) + + hiro_api_handler.get_token() + print(json.dumps(hiro_api_handler.decode_token(), indent=2)) + + hiro_api_handler.refresh_token() + print(json.dumps(hiro_api_handler.decode_token(), indent=2)) + + def test_simple_query(self): + # hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler) + # + # hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + # + # hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler2) + # + # hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + # # def query(_id: str): # hiro_client.get_node(node_id=_id) # @@ -45,3 +111,4 @@ def test_simple_query(self): # t1.start() # t2.start() # t3.start() + pass