diff --git a/src/hiro_graph_client/__init__.py b/src/hiro_graph_client/__init__.py index 9e05f8d..f80ad3d 100644 --- a/src/hiro_graph_client/__init__.py +++ b/src/hiro_graph_client/__init__.py @@ -9,8 +9,8 @@ from hiro_graph_client.authzclient import HiroAuthz from hiro_graph_client.client import HiroGraph from hiro_graph_client.clientlib import AbstractTokenApiHandler, GraphConnectionHandler, AuthenticationTokenError, \ - FixedTokenError, TokenUnauthorizedError, PasswordAuthTokenApiHandler, FixedTokenApiHandler, \ - EnvironmentTokenApiHandler, SSLConfig + FixedTokenError, TokenUnauthorizedError, PasswordAuthTokenApiHandler, CodeFlowAuthTokenApiHandler, \ + FixedTokenApiHandler, EnvironmentTokenApiHandler, SSLConfig from hiro_graph_client.iamclient import HiroIam from hiro_graph_client.kiclient import HiroKi from hiro_graph_client.variablesclient import HiroVariables @@ -20,9 +20,9 @@ __all__ = [ 'HiroGraph', 'HiroAuth', 'HiroApp', 'HiroIam', 'HiroKi', 'HiroAuthz', 'HiroVariables', 'GraphConnectionHandler', - 'AbstractTokenApiHandler', 'PasswordAuthTokenApiHandler', 'FixedTokenApiHandler', 'EnvironmentTokenApiHandler', - 'AuthenticationTokenError', 'FixedTokenError', 'TokenUnauthorizedError', '__version__', - 'SSLConfig' + 'AbstractTokenApiHandler', 'PasswordAuthTokenApiHandler', 'CodeFlowAuthTokenApiHandler', 'FixedTokenApiHandler', + 'EnvironmentTokenApiHandler', 'AuthenticationTokenError', 'FixedTokenError', 'TokenUnauthorizedError', + '__version__', 'SSLConfig' ] site.addsitedir(this_directory) diff --git a/src/hiro_graph_client/clientlib.py b/src/hiro_graph_client/clientlib.py index 96eac8c..2f0adf0 100644 --- a/src/hiro_graph_client/clientlib.py +++ b/src/hiro_graph_client/clientlib.py @@ -3,6 +3,7 @@ import json import logging import os +import secrets import threading import time import urllib @@ -11,6 +12,7 @@ from urllib.parse import quote, urlencode import backoff +import pkce import requests import requests.adapters @@ -172,7 +174,6 @@ def __init__(self, max_tries = abstract_api._max_tries else: initial_headers = { - 'Content-Type': 'application/json', 'Accept': 'text/plain, application/json', 'User-Agent': f"{client_name or self._client_name} {__version__}" } @@ -215,21 +216,29 @@ def _capitalize_header(name: str) -> str: # Basic requests ############################################################################################################### - def get_binary(self, url: str, accept: str = None) -> Iterator[bytes]: + def get_binary(self, + url: str, + accept: str = None, + headers: dict = None) -> Iterator[bytes]: """ Implementation of GET for binary data. :param url: Url to use :param accept: Mimetype for accept. Will be set to */* if not given. + :param headers: Optional additional headers. :return: Yields over raw chunks of the response payload. """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get_binary() -> Iterator[bytes]: + _headers = self._get_headers( + content_type=None, + override=headers, + remove_content_type=True) + _headers.update({"Accept": (accept or "*/*")}) + with self._session.get(url, - headers=self._get_headers( - {"Content-Type": None, "Accept": (accept or "*/*")} - ), + headers=_headers, verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, @@ -247,7 +256,8 @@ def post_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + headers: dict = None) -> Any: """ Implementation of POST for binary data. @@ -256,6 +266,7 @@ def post_binary(self, :param content_type: The content type of the data. Defaults to "application/octet-stream" internally if unset. :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @@ -264,14 +275,15 @@ def _post_binary() -> Any: res = self._session.post(url, data=data, headers=self._get_headers( - {"Content-Type": (content_type or "application/octet-stream")} + content_type=content_type or "application/octet-stream", + override=headers ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res, request_body=False) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _post_binary() @@ -279,7 +291,8 @@ def put_binary(self, url: str, data: Any, content_type: str = None, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + headers: dict = None) -> Any: """ Implementation of PUT for binary data. @@ -288,6 +301,7 @@ def put_binary(self, :param content_type: The content type of the data. Defaults to "application/octet-stream" internally if unset. :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @@ -296,46 +310,55 @@ def _put_binary() -> Any: res = self._session.put(url, data=data, headers=self._get_headers( - {"Content-Type": (content_type or "application/octet-stream")} + content_type=content_type or "application/octet-stream", + override=headers ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res, request_body=False) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _put_binary() def get(self, url: str, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + headers: dict = None) -> Any: """ Implementation of GET :param url: Url to use :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _get() -> Any: res = self._session.get(url, - headers=self._get_headers({"Content-Type": None}), + headers=self._get_headers( + content_type=None, + override=headers, + remove_content_type=True + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _get() def post(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + content_type: str = None, + headers: dict = None) -> Any: """ Implementation of POST @@ -343,27 +366,47 @@ def post(self, :param data: The payload to POST :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _post() -> Any: - res = self._session.post(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers = self._get_headers( + content_type=content_type or 'application/json', + override=headers + ) + if _headers["Content-Type"].startswith('application/json'): + res = self._session.post(url, + json=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.post(url, + data=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _post() def put(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + content_type: str = None, + headers: dict = None) -> Any: """ Implementation of PUT @@ -371,27 +414,47 @@ def put(self, :param data: The payload to PUT :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _put() -> Any: - res = self._session.put(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers = self._get_headers( + content_type=content_type or 'application/json', + override=headers + ) + if _headers["Content-Type"].startswith('application/json'): + res = self._session.put(url, + json=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.put(url, + data=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _put() def patch(self, url: str, data: Any, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + content_type: str = None, + headers: dict = None) -> Any: """ Implementation of PATCH @@ -399,45 +462,68 @@ def patch(self, :param data: The payload to PUT :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param content_type: The content type to send. For data of type dict: If 'application/json' is used, the request + will be using JSON formatting (requests parameter json=), x-www-form-urlencoded will be used otherwise + (requests parameter data=). Default is application/json. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _patch() -> Any: - res = self._session.patch(url, - json=data, - headers=self._get_headers(), - verify=self.ssl_config.get_verify(), - cert=self.ssl_config.get_cert(), - timeout=self._timeout, - proxies=self._get_proxies()) + _headers = self._get_headers( + content_type=content_type or 'application/json', + override=headers + ) + if _headers["Content-Type"].startswith('application/json'): + res = self._session.patch(url, + json=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) + else: + res = self._session.patch(url, + data=data, + headers=_headers, + verify=self.ssl_config.get_verify(), + cert=self.ssl_config.get_cert(), + timeout=self._timeout, + proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _patch() def delete(self, url: str, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None, + headers: dict = None) -> Any: """ Implementation of DELETE :param url: Url to use :param expected_media_type: The expected media type. Default is 'application/json'. If this is set to '*' or '*/*', any media_type is accepted. + :param headers: Optional additional headers. :return: The payload of the response """ @backoff.on_exception(*BACKOFF_ARGS, **BACKOFF_KWARGS, max_tries=self._get_max_tries) def _delete() -> Any: res = self._session.delete(url, - headers=self._get_headers({"Content-Type": None}), + headers=self._get_headers( + content_type=None, + override=headers, + remove_content_type=True + ), verify=self.ssl_config.get_verify(), cert=self.ssl_config.get_cert(), timeout=self._timeout, proxies=self._get_proxies()) self._log_communication(res) - return self._parse_response(res, expected_media_type) + return self._parse_response(res, expected_media_type or 'application/json') return _delete() @@ -469,17 +555,28 @@ def _merge_headers(headers: dict, override: dict) -> dict: return headers - def _get_headers(self, override: dict = None) -> dict: + def _get_headers(self, + content_type: Optional[str], + override: dict = None, + remove_content_type: bool = None) -> dict: """ - Create a header dict for requests. Uses abstract method *self._handle_token()*. + Handle header merging and content_type parameter in basic requests below. - :param override: Dict of headers that override the internal headers. If a header key is set to value None, - it will be removed from the headers. - :return: A dict containing header values for requests. + :param content_type: The desired content_type for the call. This value will be present under key + 'Content-Type' in the returned dict and overwrites every other specification in the header dicts unless + set to None. + :param override: Optional external headers, overriding the initial headers. + :param remove_content_type: Explicitly remove the key "Content-Type" from the returned header map. + :return: A dict with headers. Will always contain a key "Content-Type" - the value of which might be None. Will + also contain Authorization Bearer Token if available. """ - headers = AbstractAPI._merge_headers(self._headers.copy(), override) + if content_type: + headers.update({"Content-Type": content_type}) + if remove_content_type: + headers.pop("Content-Type") + token = self._handle_token() if token: headers['Authorization'] = "Bearer " + token @@ -503,17 +600,17 @@ def _bool_to_external_str(value: Any) -> Optional[str]: @staticmethod def _get_query_part(params: dict) -> str: """ - Create the query part of an url. Keys in *params* whose values are set to None are removed. + Create the query part of an url. Keys in *params* whose values are set to None or empty are removed. :param params: A dict of params to use for the query. :return: The query part of an url with a leading '?', or an empty string when query is empty. """ - params_cleaned = {k: AbstractAPI._bool_to_external_str(v) for k, v in params.items() if v is not None} + params_cleaned = {k: AbstractAPI._bool_to_external_str(v) for k, v in params.items() if v} return ('?' + urlencode(params_cleaned, quote_via=quote, safe="/,")) if params_cleaned else "" def _parse_response(self, res: requests.Response, - expected_media_type: str = 'application/json') -> Any: + expected_media_type: str = None) -> Any: """ Parse the response of the backend. @@ -529,6 +626,8 @@ def _parse_response(self, try: self._check_response(res) self._check_status_error(res) + if not expected_media_type: + expected_media_type = 'application/json' if expected_media_type not in ['*', '*/*']: AbstractAPI._check_content_type(res, expected_media_type) if expected_media_type.lower() == 'application/json': @@ -966,9 +1065,15 @@ def decode_token_ext(token: str) -> dict: return dict(json.loads(json_payload)) @abstractmethod - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: """ Refresh the current token. + + Organization information does not overwrite the internal organization or organization_id. + + :param organization: Optional name of an organization to be used for the token. + :param organization_id: Optional id of an organization to be used for the token. Overrides + parameter *organization*. """ raise RuntimeError('Cannot use method of this abstract class.') @@ -1019,7 +1124,7 @@ def __init__(self, token: str = None, *args, **kwargs): def token(self) -> str: return self._token - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: raise FixedTokenError('Token is invalid and cannot be changed because it has been given externally.') def revoke_token(self, token_hint: str = "revoke_token") -> None: @@ -1060,7 +1165,7 @@ def __init__(self, env_var: str = 'HIRO_TOKEN', *args, **kwargs): def token(self) -> str: return os.environ[self._env_var] - def refresh_token(self) -> None: + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: raise FixedTokenError( "Token is invalid and cannot be changed because it has been given as environment variable '{}'" " externally.".format(self._env_var)) @@ -1187,10 +1292,9 @@ def clear_token_data(self, access_token_only: bool): self.expires_at = -1 -class PasswordAuthTokenApiHandler(AbstractTokenApiHandler): +class AbstractRemoteTokenApiHandler(AbstractTokenApiHandler): """ - API Tokens will be fetched using this class. It does not handle any automatic token fetching, refresh or token - expiry. This has to be checked and triggered by the *caller*. + Remote API Tokens will be fetched using this class. The methods of this class are thread-safe, so it can be shared between several HIRO objects. @@ -1203,19 +1307,20 @@ class PasswordAuthTokenApiHandler(AbstractTokenApiHandler): _lock: threading.RLock """Reentrant mutex for thread safety""" - _username: str - _password: str _client_id: str _client_secret: str + _organization: str + _organization_id: str + _secure_logging: bool = True """Avoid logging of sensitive data.""" def __init__(self, - username: str = None, - password: str = None, client_id: str = None, client_secret: str = None, + organization: str = None, + organization_id: str = None, secure_logging: bool = True, *args, **kwargs): """ @@ -1224,20 +1329,21 @@ def __init__(self, See parent :class:`AbstractTokenApiHandler` for a full description of all remaining parameters. - :param username: Username for authentication - :param password: Password for authentication :param client_id: OAuth client_id for authentication - :param client_secret: OAuth client_secret for authentication + :param client_secret: OAuth client_secret for authentication (This can be None in special cases). + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. - :param args: Unnamed parameter passthrough for parent class. - :param kwargs: Named parameter passthrough for parent class. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. """ super().__init__(*args, **kwargs) - self._username = username - self._password = password self._client_id = client_id self._client_secret = client_secret + self._organization = organization + self._organization_id = organization_id self._secure_logging = secure_logging @@ -1274,44 +1380,29 @@ def _log_communication(self, res: requests.Response, request_body: bool = True, super()._log_communication(res, request_body=log_request_body, response_body=log_response_body) - def get_token(self) -> None: + @abstractmethod + def get_token(self, organization: str = None, organization_id: str = None) -> None: """ - Construct a request to obtain a new token. API self._endpoint + '/app' + Construct a request to obtain a new token. + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. :raises AuthenticationTokenError: When no auth_endpoint is set. """ - with self._lock: - if not self.endpoint: - raise AuthenticationTokenError( - 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') - - if not self._username or not self._password or not self._client_id or not self._client_secret: - msg = "" - if not self._username: - msg += "'username'" - if not self._password: - msg += (", " if msg else "") + "'password'" - if not self._client_id: - msg += (", " if msg else "") + "'client_id'" - if not self._client_secret: - msg += (", " if msg else "") + "'client_secret'" - raise AuthenticationTokenError( - "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + raise RuntimeError('Cannot use method of this abstract class.') - url = self.endpoint + '/app' - data = { - "client_id": self._client_id, - "client_secret": self._client_secret, - "username": self._username, - "password": self._password - } + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to refresh an existing token. - res = self.post(url, data) - self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + API self._endpoint + '/refresh'. (until /api/auth/6.5) + API self._endpoint + '/token'. (since /api/auth/6.6) - def refresh_token(self) -> None: - """ - Construct a request to refresh an existing token. API self._endpoint + '/refresh'. + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. :raises AuthenticationTokenError: When no auth_endpoint is set. """ @@ -1321,21 +1412,38 @@ def refresh_token(self) -> None: 'Token is invalid and endpoint (auth_endpoint) for refresh is not set.') if not self._token_info.refresh_token: - self.get_token() + self.get_token(organization=organization, organization_id=organization_id) return - url = self.endpoint + '/refresh' - data = { + auth_api_version = float(self._version_info['auth']['version']) + + if auth_api_version >= 6.6: + url = self.endpoint + '/token' + content_type = 'application/x-www-form-urlencoded' + data = { + "grant_type": "refresh_token" + } + else: + url = self.endpoint + '/refresh' + content_type = 'application/json' + data = {} + + self._organization = organization + self._organization_id = organization_id + + data.update({ "client_id": self._client_id, "client_secret": self._client_secret, - "refresh_token": self._token_info.refresh_token - } + "refresh_token": self._token_info.refresh_token, + "organization": self._organization, + "organization_id": self._organization_id + }) try: - res = self.post(url, data) + res = self.post(url, data, content_type=content_type) self._token_info.parse_token_result(res, "{}.refresh_token".format(self.__class__.__name__)) except AuthenticationTokenError: - self.get_token() + self.get_token(organization=organization, organization_id=organization_id) def revoke_token(self, token_hint: str = "refresh_token") -> None: """ @@ -1413,6 +1521,267 @@ def _handle_token(self) -> Optional[str]: return None +class PasswordAuthTokenApiHandler(AbstractRemoteTokenApiHandler): + """ + Implements the OAuth2 password auth flow. + + The methods of this class are thread-safe, so it can be shared between several HIRO objects. + + It is built this way to avoid endless calling loops when resolving tokens. + """ + _username: str + _password: str + + def __init__(self, + username: str = None, + password: str = None, + *args, **kwargs): + """ + Constructor + + See parent :class:`AbstractRemoteTokenApiHandler` for a full description + of all remaining parameters. + + :param username: Username for authentication + :param password: Password for authentication + :param client_id: OAuth client_id for authentication + :param client_secret: OAuth client_secret for authentication + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. + :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. + """ + super().__init__(*args, **kwargs) + + self._username = username + self._password = password + + def get_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to obtain a new token. This is the OAuth2 password auth flow. + + API self._endpoint + '/app' (until /api/auth/6.5) + API self._endpoint + '/token'. (since /api/auth/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. + :raises AuthenticationTokenError: When no auth_endpoint is set. + """ + with self._lock: + if not self.endpoint: + raise AuthenticationTokenError( + 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') + + if not self._username or not self._password or not self._client_id or not self._client_secret: + msg = "" + if not self._username: + msg += "'username'" + if not self._password: + msg += (", " if msg else "") + "'password'" + if not self._client_id: + msg += (", " if msg else "") + "'client_id'" + if not self._client_secret: + msg += (", " if msg else "") + "'client_secret'" + raise AuthenticationTokenError( + "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + + auth_api_version = float(self._version_info['auth']['version']) + + if auth_api_version >= 6.6: + url = self.endpoint + '/token' + content_type = 'application/x-www-form-urlencoded' + data = { + "grant_type": "password" + } + else: + url = self.endpoint + '/app' + content_type = 'application/json' + data = {} + + self._organization = organization + self._organization_id = organization_id + + data.update({ + "client_id": self._client_id, + "client_secret": self._client_secret, + "username": self._username, + "password": self._password, + "organization": self._organization, + "organization_id": self._organization_id + }) + + res = self.post(url, data, content_type=content_type) + self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + + +class CodeFlowAuthTokenApiHandler(AbstractRemoteTokenApiHandler): + """ + Implements the OAuth2 authorization_code auth flow. This is only the second step where you already + obtained the code parameter from the authorization server. + + Be aware, that obtaining a token can only be used once with the same value of *self._code*. + + The methods of this class are thread-safe, so it can be shared between several HIRO objects. + + It is built this way to avoid endless calling loops when resolving tokens. + """ + _scope: str + _redirect_uri: str + _code_verifier: str + _code: str + _state: str + + def __init__(self, + redirect_uri: str, + scope: str = None, + *args, **kwargs): + """ + Constructor + + Be aware, that obtaining a token via authentication_code flow can only be done once with the same value of + *self._code*. If a token refresh fails, you need to start a new authorization code flow. + + See parent :class:`AbstractRemoteTokenApiHandler` for a full description + of all remaining parameters. + + :param redirect_uri: The redirect_uri parameter for the authorization-redirect call. + :param scope: Optional scope for OAuth login. + :param client_id: OAuth client_id for authentication + :param client_secret: OAuth client_secret for authentication. This is optional here. + :param organization: Optional name of an organization to be used for the token requests. + :param organization_id: Optional id of an organization to be used for the token requests. Overrides + parameter *organization*. + :param secure_logging: If this is enabled, payloads that might contain sensitive information are not logged. + :param args: Unnamed parameter passthrough for parent class. + :param kwargs: Named parameter passthrough for parent class. + """ + super().__init__(*args, **kwargs) + + self._scope = scope + self._redirect_uri = redirect_uri + + def get_authorize_uri(self) -> str: + """ + Construct an authorization uri for your browser. + + :return: The uri for the browser. + """ + if not self._redirect_uri: + raise AuthenticationTokenError("redirect_uri is missing", 400) + + url = self.endpoint + "/authorize" + + # Initialize state and code_verifier anew with each call. + self._state = secrets.token_urlsafe(16) + self._code_verifier = pkce.generate_code_verifier(length=64) + + data = { + "response_type": "code", + "client_id": self._client_id, + "redirect_uri": self._redirect_uri, + "code_challenge": pkce.get_code_challenge(self._code_verifier), + "code_challenge_method": "S256", + "state": self._state, + "scope": self._scope + } + + return url + self._get_query_part(data) + + def handle_authorize_callback(self, state: str, code: str) -> None: + """ + Saves the code and matches the state from the authorization callback. + + :param state: The state returned by the callback. Must not have changed. + :param code: The one-time-code returned (used to obtain a full authorization token). + """ + if state != self._state: + raise AuthenticationTokenError("The parameter 'state' of the callback does not match.", 400) + + self._code = code + + def get_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to obtain a new token. This is the second step of the authorization_code flow. + + Be aware, that this method can only be used once with the same value of *self._code*. If a token has been + obtained here, this method cannot be re-used. If a token refresh fails, you need to start a new authorization + code flow. + + API self._endpoint + '/token' (since /auth/api/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. + :raises AuthenticationTokenError: When no auth_endpoint is set or /api/auth/${version} is below 6.6. + """ + with self._lock: + if not self.endpoint: + raise AuthenticationTokenError( + 'Token is invalid and endpoint (auth_endpoint) for obtaining is not set.') + + if not self._code or not self._code_verifier or not self._client_id: + msg = "" + if not self._code: + msg += "'code'" + if not self._client_id: + msg += (", " if msg else "") + "'client_id'" + if not self._redirect_uri: + msg += (", " if msg else "") + "'redirect_uri'" + raise AuthenticationTokenError( + "{} is missing required parameter(s) {}.".format(self.__class__.__name__, msg)) + + auth_api_version = float(self._version_info['auth']['version']) + if auth_api_version < 6.6: + raise AuthenticationTokenError("Auth api version /api/auth/[version] has to be at least 6.6.") + + url = self.endpoint + '/token' + + # If a client_secret is present, set "client_id" and "client_secret", if not, only add "clientId" as + # form param (peculiar of WSO2). + if self._client_secret: + data = { + "client_id": self._client_id, + "client_secret": self._client_secret + } + else: + data = { + "clientId": self._client_id + } + + self._organization = organization + self._organization_id = organization_id + + data.update({ + "grant_type": "authorization_code", + "code": self._code, + "code_verifier": self._code_verifier, + "redirect_uri": self._redirect_uri, + "organization": self._organization, + "organization_id": self._organization_id + }) + + res = self.post(url, data, content_type='application/x-www-form-urlencoded') + self._token_info.parse_token_result(res, "{}.get_token".format(self.__class__.__name__)) + + def refresh_token(self, organization: str = None, organization_id: str = None) -> None: + """ + Construct a request to refresh an existing token. This fails immediately if no refresh_token is available. + + API self._endpoint + '/token'. (since /api/auth/6.6) + + :param organization: Optional name of an organization to be used for this token. + :param organization_id: Optional id of an organization to be used for this token. Overrides + parameter *organization*. + + :raises AuthenticationTokenError: When no auth_endpoint or refresh_token is set. + """ + if not self._token_info.refresh_token: + raise AuthenticationTokenError("Token cannot be refreshed without a refresh_token.") + + super().refresh_token(organization, organization_id) + + ################################################################################################################### # Root class for different API groups ################################################################################################################### diff --git a/src/hiro_graph_client/requirements.txt b/src/hiro_graph_client/requirements.txt index 38a1c6a..d1573b1 100644 --- a/src/hiro_graph_client/requirements.txt +++ b/src/hiro_graph_client/requirements.txt @@ -4,4 +4,5 @@ backoff==2.0.1 setuptools==62.1.0 websocket-client==1.3.2 -apscheduler==3.9.1 \ No newline at end of file +apscheduler==3.9.1 +pkce==1.0.3 \ No newline at end of file diff --git a/src/setup.py b/src/setup.py index e95ce12..6c75d7e 100644 --- a/src/setup.py +++ b/src/setup.py @@ -31,7 +31,8 @@ 'requests', 'backoff', 'websocket-client', - 'apscheduler' + 'apscheduler', + 'pkce' ], package_data={ name: ['VERSION'] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ca7a1ee..633d507 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,40 +1,106 @@ -from hiro_graph_client import PasswordAuthTokenApiHandler, HiroGraph, SSLConfig, GraphConnectionHandler -from .testconfig import CONFIG +import base64 +import gzip +import json + +from hiro_graph_client import CodeFlowAuthTokenApiHandler, PasswordAuthTokenApiHandler, SSLConfig, \ + GraphConnectionHandler, HiroApp, HiroAuth +from .testconfig import CONFIG_STAGE as CONFIG class TestClient: + connection_handler = GraphConnectionHandler( root_url=CONFIG.get('URL'), ssl_config=SSLConfig(verify=False) ) - hiro_api_handler = PasswordAuthTokenApiHandler( - username=CONFIG.get('USERNAME'), - password=CONFIG.get('PASSWORD'), - client_id=CONFIG.get('CLIENT_ID'), - client_secret=CONFIG.get('CLIENT_SECRET'), - secure_logging=False, - connection_handler=connection_handler - ) + def test_download_config(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) - hiro_api_handler2 = PasswordAuthTokenApiHandler( - username=CONFIG.get('USERNAME'), - password=CONFIG.get('PASSWORD'), - client_id=CONFIG.get('CLIENT_ID'), - client_secret=CONFIG.get('CLIENT_SECRET'), - secure_logging=False, - connection_handler=connection_handler - ) + hiro_app: HiroApp = HiroApp(api_handler=hiro_api_handler) - def test_simple_query(self): - hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler) + result = hiro_app.get_config() + + content = result.get("content") + content_type = result.get("type") + + content_bytes = base64.b64decode(content) if "base64" in content_type else bytearray(content, encoding='utf8') + + config_data = str(gzip.decompress(content_bytes), encoding='utf8') if "gz" in content_type else str( + content_bytes, encoding='utf8') + + with open("connector.conf", "w") as file: + print(config_data, file=file) + + def test_connect(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) + + hiro_auth: HiroAuth = HiroAuth(api_handler=hiro_api_handler) + + print(json.dumps(hiro_auth.get_identity(), indent=2)) - hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + def test_refresh_token(self): + hiro_api_handler = PasswordAuthTokenApiHandler( + username=CONFIG.get('USERNAME'), + password=CONFIG.get('PASSWORD'), + client_id=CONFIG.get('CLIENT_ID'), + client_secret=CONFIG.get('CLIENT_SECRET'), + secure_logging=False, + connection_handler=self.connection_handler + ) - hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler2) + handler = hiro_api_handler - hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + handler.get_token() + handler.refresh_token() + hiro_auth: HiroAuth = HiroAuth(api_handler=handler) + + print(json.dumps(hiro_auth.get_identity(), indent=2)) + + handler.revoke_token() + + print(json.dumps(hiro_auth.get_identity(), indent=2)) + + def test_auth_code_flow(self): + hiro_api_handler = CodeFlowAuthTokenApiHandler( + client_id=CONFIG.get('CLIENT_ID'), + code='4b12c5f1-cdd7-3487-a5b0-82305b011b67', + code_verifier='BQ3kyADe6RKYHttFwGPwvLrX_B6zwCr2vNRe00TQLfoRo-HYhUqM8sPKUQlkhbwUQxli2ZneFhFqx4xbltI4WQ', + redirect_uri='http://wksw-whuebner.arago.de:8080/oauth2/app/callback.xhtml', + secure_logging=False, + connection_handler=self.connection_handler + ) + + hiro_api_handler.get_token() + print(json.dumps(hiro_api_handler.decode_token(), indent=2)) + + hiro_api_handler.refresh_token() + print(json.dumps(hiro_api_handler.decode_token(), indent=2)) + + def test_simple_query(self): + # hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler) + # + # hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + # + # hiro_client: HiroGraph = HiroGraph(api_handler=self.hiro_api_handler2) + # + # hiro_client.get_node(node_id="ckqjkt42s0fgf0883pf0cb0hx_ckqjl014l0hvr0883hxcvmcwq", meta=True) + # # def query(_id: str): # hiro_client.get_node(node_id=_id) # @@ -45,3 +111,4 @@ def test_simple_query(self): # t1.start() # t2.start() # t3.start() + pass