diff --git a/.flake8 b/.flake8 index 62a48b1697..88abba3426 100644 --- a/.flake8 +++ b/.flake8 @@ -126,7 +126,7 @@ per-file-ignores = src/awx_plugins/credentials/dsv.py: ANN003, ANN201, D100, D103, P103, WPS210 # NOTE: `awx_plugins.credentials.github_app` may need restructuring to comply with WPS 202. src/awx_plugins/credentials/github_app.py: WPS202 - src/awx_plugins/credentials/hashivault.py: ANN003, ANN201, B950, C901, CCR001, D100, D103, LN001, N400, WPS202, WPS204, WPS210, WPS221, WPS223, WPS229, WPS231, WPS232, WPS331, WPS336, WPS337, WPS432, WPS454 + src/awx_plugins/credentials/hashivault.py: ANN003, ANN201, B950, C901, CCR001, D100, D103, LN001, N400, WPS201, WPS202, WPS204, WPS210, WPS221, WPS223, WPS229, WPS231, WPS232, WPS331, WPS336, WPS337, WPS432, WPS454 src/awx_plugins/credentials/injectors.py: ANN001, ANN201, ANN202, C408, D100, D103, WPS110, WPS111, WPS202, WPS210, WPS347, WPS433, WPS440 src/awx_plugins/credentials/plugin.py: ANN001, ANN002, ANN101, ANN201, ANN204, B010, D100, D101, D103, D105, D107, D205, D400, E731, WPS432, WPS433, WPS440, WPS442, WPS601 src/awx_plugins/credentials/plugins.py: B950,D100, D101, D103, D105, D107, D205, D400, LN001, WPS204, WPS229, WPS433, WPS440 diff --git a/.mypy.ini b/.mypy.ini index 8520d2ae1a..27122d135b 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -111,3 +111,12 @@ disallow_any_expr = false disallow_any_expr = false # fails on `@hypothesis.given()`: disallow_any_decorated = false + +# It is sometimes necessary to access "officially-private" attributes in our +# own tests for the purposes of assertion, mocking and spying on objects. And +# that's fine since it's within the same repository fully under our control, +# and so we don't need to have MyPy complain about those in this specific +# context. This is why we disable `attr-defined` in our tests [1]. +# +# [1] https://til.codeinthehole.com/posts/how-to-handle-convenience-imports-with-mypy/ +disable_error_code = attr-defined diff --git a/docs/conf.py b/docs/conf.py index 1cb0f78ee1..8e116a2339 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -233,6 +233,12 @@ 'py:class', 'awx_plugins.interfaces._temporary_private_credential_api.Credential', ), + ( # generic return type variable: + 'py:class', + 'awx_plugins.credentials.hashivault._RT', + ), + ('py:class', '_PT'), # generic ParamSpec type variable + ('py:class', '_contextvars.ContextVar'), # unresolved context var type ('py:class', 'EnvVarsType'), ] diff --git a/src/awx_plugins/credentials/hashivault.py b/src/awx_plugins/credentials/hashivault.py index 293af0e6dd..93ee6b7735 100644 --- a/src/awx_plugins/credentials/hashivault.py +++ b/src/awx_plugins/credentials/hashivault.py @@ -1,9 +1,14 @@ # FIXME: the following violations must be addressed gradually and unignored # mypy: disable-error-code="arg-type, no-untyped-call, no-untyped-def" +import contextlib as _ctx +import contextvars as _ctx_vars +import functools as _functools import os import pathlib import time +import typing as _t +from collections import abc as _abc from urllib.parse import urljoin from awx_plugins.interfaces._temporary_private_django_api import ( # noqa: WPS436 @@ -16,6 +21,14 @@ from .plugin import CertFiles, CredentialPlugin, raise_for_status +_AUTH_TOKEN: _ctx_vars.ContextVar[str] = _ctx_vars.ContextVar('_AUTH_TOKEN') +"""Authentication token for use in plugin handlers.""" + + +class _EmptyKwargs(_t.TypedDict): + """Schema for zero keyword arguments.""" + + # Base input fields url_field: _types.FieldDict = { 'id': 'url', @@ -481,6 +494,76 @@ def workload_identity_auth(**kwargs): return {'role': kwargs.get('jwt_role'), 'jwt': workload_identity_token} +def _revoke_self_token( + *, + vault_token: str, + url: str, + namespace: str, + cacert: str | None = None, +) -> None: + """Revoke the passed-in Vault token.""" + url = urljoin(url, 'v1/auth/token/revoke-self') + sess = requests.Session() + sess.headers['X-Vault-Token'] = vault_token + if namespace != '': + sess.headers['X-Vault-Namespace'] = namespace + with CertFiles(cacert) as cert: + resp = sess.post(url, verify=cert, timeout=30) + resp.raise_for_status() + + +@_ctx.contextmanager +def _vault_token(**kwargs: str) -> _abc.Iterator[str]: + """Context manager that yields a Vault token and revokes it on exit if obtained via workload identity.""" + is_oidc_context = 'workload_identity_token' in kwargs + token = handle_auth(**kwargs) + try: + yield token + finally: + # Only revoke tokens obtained via OIDC authentication + if is_oidc_context: + _revoke_self_token( + vault_token=token, + url=kwargs['url'], + namespace=kwargs.get('namespace', ''), + cacert=kwargs.get('cacert'), + ) + + +@_ctx.contextmanager +def _token_in_context(token: str, /) -> _abc.Iterator[None]: + """Set a token for the execution context lifetime.""" + var_ctx_token = _AUTH_TOKEN.set(token) + try: + yield + finally: + _AUTH_TOKEN.reset(var_ctx_token) + + +# Param Spec to represent decorated function parameters +_PT = _t.ParamSpec( # FIXME: Use [_RT, **_PT] in the signature in Python 3.12 + '_PT', +) +# TypeVar to represent decorated function return type +_RT = _t.TypeVar('_RT') + + +def _inject_auth_token_with_revocation( + decorated_function: _t.Callable[_PT, _RT], + /, +) -> _t.Callable[_PT, _RT]: + @_functools.wraps(decorated_function) + def _decorate_the_function_with_revocation( # noqa: WPS430 -- in-decorator + *args: _PT.args, + **kwargs: _PT.kwargs, + ) -> _RT: + with _vault_token(**kwargs) as token: + with _token_in_context(token): + return decorated_function(*args, **kwargs) + + return _decorate_the_function_with_revocation + + def method_auth(**kwargs): # get auth method specific params request_kwargs = {'json': kwargs['auth_param'], 'timeout': 30} @@ -522,15 +605,22 @@ def method_auth(**kwargs): return token -def kv_backend(**kwargs): - token = handle_auth(**kwargs) - url = kwargs['url'] - secret_path = kwargs['secret_path'] - secret_backend = kwargs.get('secret_backend') - secret_key = kwargs.get('secret_key') - cacert = kwargs.get('cacert') - api_version = kwargs['api_version'] - +@_inject_auth_token_with_revocation +# NOTE: The "too many args" rules of flake8 and pylint are disabled due to such +# NOTE: many arguments being a common public plugin API at the moment. +# pylint: disable-next=too-many-arguments +def kv_backend( # noqa: WPS211 -- the same as too-many-arguments + *, + url: str, + api_version: str, + secret_path: str, + secret_key: str | None = None, + secret_backend: str | None = None, + secret_version: str | None = None, + cacert: str | None = None, + namespace: str | None = None, + **_discarded_kwargs: _t.Unpack[_EmptyKwargs], +) -> str: request_kwargs = { 'timeout': 30, 'allow_redirects': False, @@ -538,16 +628,16 @@ def kv_backend(**kwargs): sess = requests.Session() sess.mount(url, requests.adapters.HTTPAdapter(max_retries=5)) - sess.headers['Authorization'] = f'Bearer {token}' + sess.headers['Authorization'] = f'Bearer {_AUTH_TOKEN.get()}' # Compatibility header for older installs of Hashicorp Vault - sess.headers['X-Vault-Token'] = token - if kwargs.get('namespace'): - sess.headers['X-Vault-Namespace'] = kwargs['namespace'] + sess.headers['X-Vault-Token'] = _AUTH_TOKEN.get() + if namespace: + sess.headers['X-Vault-Namespace'] = namespace if api_version == 'v2': - if kwargs.get('secret_version'): + if secret_version: request_kwargs['params'] = { # type: ignore[assignment] # FIXME - 'version': kwargs['secret_version'], + 'version': secret_version, } if secret_backend: path_segments = [secret_backend, 'data', secret_path] @@ -556,7 +646,6 @@ def kv_backend(**kwargs): mount_point, *path = pathlib.Path( secret_path.lstrip(os.sep), ).parts - '/'.join(path) except Exception: mount_point, path = secret_path, [] # https://www.vaultproject.io/api/secret/kv/kv-v2.html#read-secret-version @@ -593,40 +682,51 @@ def kv_backend(**kwargs): ) and ('data' in json['data']) ): - return json['data']['data'][secret_key] - return json['data'][secret_key] + return str(json['data']['data'][secret_key]) + return str(json['data'][secret_key]) except KeyError: - raise RuntimeError(f'{secret_key} is not present at {secret_path}') - return json['data'] - - -def ssh_backend(**kwargs): - token = handle_auth(**kwargs) - url = urljoin(kwargs['url'], 'v1') - secret_path = kwargs['secret_path'] - role = kwargs['role'] - cacert = kwargs.get('cacert') + raise RuntimeError( + f'{secret_key} is not present at {secret_path}', + ) + return str(json['data']) + + +@_inject_auth_token_with_revocation +# NOTE: The "too many args" rules of flake8 and pylint are disabled due to such +# NOTE: many arguments being a common public plugin API at the moment. +# pylint: disable-next=too-many-arguments +def ssh_backend( # noqa: WPS211 -- the same as too-many-arguments + *, + url: str, + secret_path: str, + role: str, + public_key: str, + cacert: str | None = None, + namespace: str | None = None, + valid_principals: str | None = None, + **_discarded_kwargs: _t.Unpack[_EmptyKwargs], +) -> str: + url = urljoin(url, 'v1') request_kwargs = { 'timeout': 30, 'allow_redirects': False, + 'json': { + 'public_key': public_key, + }, } - - request_kwargs['json'] = { # type: ignore[assignment] # FIXME - 'public_key': kwargs['public_key'], - } - if kwargs.get('valid_principals'): + if valid_principals: request_kwargs['json'][ # type: ignore[index] # FIXME 'valid_principals' - ] = kwargs['valid_principals'] + ] = valid_principals sess = requests.Session() sess.mount(url, requests.adapters.HTTPAdapter(max_retries=5)) - sess.headers['Authorization'] = f'Bearer {token}' - if kwargs.get('namespace'): - sess.headers['X-Vault-Namespace'] = kwargs['namespace'] + sess.headers['Authorization'] = f'Bearer {_AUTH_TOKEN.get()}' + if namespace: + sess.headers['X-Vault-Namespace'] = namespace # Compatibility header for older installs of Hashicorp Vault - sess.headers['X-Vault-Token'] = token + sess.headers['X-Vault-Token'] = _AUTH_TOKEN.get() # https://www.vaultproject.io/api/secret/ssh/index.html#sign-ssh-key request_url = '/'.join([url, secret_path, 'sign', role]).rstrip('/') @@ -642,7 +742,7 @@ def ssh_backend(**kwargs): else: break raise_for_status(resp) - return resp.json()['data']['signed_key'] + return str(resp.json()['data']['signed_key']) hashivault_kv_plugin = CredentialPlugin( diff --git a/tests/unit/credentials/hashivault_test.py b/tests/unit/credentials/hashivault_test.py index 2ed48fb441..918d7ef94a 100644 --- a/tests/unit/credentials/hashivault_test.py +++ b/tests/unit/credentials/hashivault_test.py @@ -1,7 +1,13 @@ +# pylint: disable=protected-access # tests access private methods legitimately """Tests for HashiCorp Vault credential plugins.""" import typing as _t +# NOTE: The forbidden import here is only used for typing, which warrants +# NOTE: suppressing the respective pylint and ruff rules. +# pylint: disable-next=deprecated-module,preferred-module +from unittest import mock as _unittest_mock # noqa: TID251 -- forbidden import + import pytest from pytest_mock import MockerFixture @@ -395,6 +401,297 @@ def test_non_oidc_plugins_have_no_internal_fields( assert internal_fields == [] +@pytest.fixture +def handle_auth_mock(mocker: MockerFixture) -> object: + """Make a mocked ``handle_auth`` callable returning 'test_token'.""" + return mocker.patch.object( + hashivault, + 'handle_auth', + autospec=True, + return_value='test_token', + ) + + +@pytest.fixture +def session_class_mock(mocker: MockerFixture) -> object: + """Make a mocked ``requests.Session`` class dummy.""" + return mocker.patch.object( + hashivault.requests, + 'Session', + autospec=True, + ) + + +@pytest.fixture +def _cert_files_mock(mocker: MockerFixture) -> None: + """Replace ``CertFiles`` class with a dummy.""" + mocker.patch.object( + hashivault, + 'CertFiles', + autospec=True, + ).return_value.__enter__.return_value = 'cert_path' + + +@pytest.fixture +def _suppress_raise_for_status(mocker: MockerFixture) -> None: + """Suppress ``requests``' HTTP return code checks.""" + mocker.patch.object(hashivault, 'raise_for_status', autospec=True) + + +def test_vault_token_no_workload_identity( + handle_auth_mock: _unittest_mock.Mock, + session_class_mock: _unittest_mock.Mock, +) -> None: + """Test ``_vault_token`` context manager doesn't revoke token without workload_identity_token.""" + vault_token_kwargs = { + 'url': 'https://vault.example.com', + 'token': 'test_token', + } + + with hashivault._vault_token(**vault_token_kwargs) as token: + assert token == 'test_token' + + handle_auth_mock.assert_called_once_with(**vault_token_kwargs) + session_class_mock.return_value.post.assert_not_called() + + +@pytest.mark.parametrize( + ('extra_kwargs', 'expected_headers'), + ( + pytest.param( + {}, + {'X-Vault-Token': 'test_token'}, + id='without-namespace', + ), + pytest.param( + {'namespace': 'test-namespace'}, + { + 'X-Vault-Token': 'test_token', + 'X-Vault-Namespace': 'test-namespace', + }, + id='with-namespace', + ), + ), +) +@pytest.mark.usefixtures('_cert_files_mock') +def test_vault_token_revokes_oidc_token( + handle_auth_mock: _unittest_mock.Mock, + session_class_mock: _unittest_mock.Mock, + extra_kwargs: dict[str, str], + expected_headers: dict[str, str], +) -> None: + """Test ``_vault_token`` context manager revokes token for workload identity auth.""" + mock_session = session_class_mock.return_value + mock_session.headers = {} + + kwargs = { + 'url': 'https://vault.example.com', + 'workload_identity_token': 'jwt_token', + 'jwt_role': 'test_role', + 'default_auth_path': 'jwt', + **extra_kwargs, + } + + with hashivault._vault_token(**kwargs) as token: + assert token == 'test_token' + + handle_auth_mock.assert_called_once_with(**kwargs) + for header, header_contents in expected_headers.items(): + assert mock_session.headers[header] == header_contents + mock_session.post.assert_called_once() + assert 'auth/token/revoke-self' in mock_session.post.call_args[0][0] + mock_session.post.return_value.raise_for_status.assert_called_once() + + +@pytest.mark.usefixtures('_cert_files_mock') +def test_vault_token_revoke_failure( + handle_auth_mock: _unittest_mock.Mock, + session_class_mock: _unittest_mock.Mock, +) -> None: + """Test ``_vault_token`` context manager raises when token revocation fails.""" + mock_session = session_class_mock.return_value + mock_session.headers = {} + mock_session.post.return_value.raise_for_status.side_effect = ( + hashivault.requests.HTTPError( + '403 Client Error: Forbidden for url: ...', + ) + ) + + kwargs = { + 'url': 'https://vault.example.com', + 'workload_identity_token': 'jwt_token', + 'jwt_role': 'test_role', + 'default_auth_path': 'jwt', + } + + def _use_vault_token() -> None: + with hashivault._vault_token(**kwargs) as token: + assert token == 'test_token' + + with pytest.raises( + hashivault.requests.HTTPError, + match='403 Client Error', + ): + _use_vault_token() + + handle_auth_mock.assert_called_once_with(**kwargs) + + +@pytest.mark.parametrize( + ('backend_func', 'extra_kwargs'), + ( + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v1', + 'secret_key': 'password', + 'secret_path': '/secret/path', + }, + id='kv-backend-v1', + ), + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v2', + 'secret_key': 'password', + 'secret_path': '/secret/path', + }, + id='kv-backend-v2', + ), + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v2', + 'secret_key': 'password', + 'namespace': 'test-namespace', + 'secret_path': '/secret/path', + }, + id='kv-backend-v2-with-namespace', + ), + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v2', + 'secret_key': 'password', + 'secret_version': '3', + 'secret_path': '/secret/path', + }, + id='kv-backend-v2-with-secret-version', + ), + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v2', + 'secret_key': 'password', + 'secret_backend': 'kv', + 'secret_path': '/secret/path', + }, + id='kv-backend-v2-with-secret-backend', + ), + pytest.param( + hashivault.kv_backend, + { + 'api_version': 'v2', + 'secret_key': 'password', + 'namespace': 'test-namespace', + 'secret_backend': 'kv', + 'secret_version': '5', + 'secret_path': '/secret/path', + }, + id='kv-backend-v2-with-all-params', + ), + pytest.param( + hashivault.ssh_backend, + { + 'secret_path': '/ssh', + 'role': 'test_ssh_role', + 'public_key': 'ssh-rsa AAAAB...', + }, + id='ssh-backend', + ), + pytest.param( + hashivault.ssh_backend, + { + 'namespace': 'test-namespace', + 'secret_path': '/ssh', + 'role': 'test_ssh_role', + 'public_key': 'ssh-rsa AAAAB...', + }, + id='ssh-backend-with-namespace', + ), + ), +) +@pytest.mark.usefixtures('_cert_files_mock', '_suppress_raise_for_status') +def test_backend_revokes_oidc_token( + handle_auth_mock: _unittest_mock.Mock, + mocker: MockerFixture, + session_class_mock: _unittest_mock.Mock, + backend_func: _t.Callable[[dict[str, str]], str], + extra_kwargs: dict[str, str], +) -> None: + """Test backend functions revoke token via context manager for OIDC auth.""" + # Common base kwargs for all OIDC auth scenarios + base_kwargs = { + 'url': 'https://vault.example.com', + 'workload_identity_token': 'jwt_token', + 'jwt_role': 'test_role', + 'default_auth_path': 'jwt', + } + + backend_kwargs = {**base_kwargs, **extra_kwargs} + + mock_get_or_post_session = mocker.Mock() + mock_get_or_post_session.headers = {} + + # Configure mock response for secret fetch based on backend type + # and use it as a get or post method on the session mock + mock_secret_response = mocker.Mock() + mock_secret_response.status_code = 200 + + # MyPy is unhappy due to incomplete typing of the `backend_func` param but + # we don't care too much for now. + if backend_func is hashivault.kv_backend: # type: ignore[comparison-overlap] + mock_get_or_post_session.get.return_value = mock_secret_response + + api_version = extra_kwargs.get('api_version', 'v1') + # kv_backend v1 expects {'data': {secret_key: value}} + mock_secret_response.json.return_value = { + 'data': {'password': 'test_value'}, + } + if api_version == 'v2': + # kv_backend v2 expects {'data': {'data': {secret_key: value}}} + mock_secret_response.json.return_value = { + 'data': mock_secret_response.json.return_value, + } + else: # ssh_backend + mock_get_or_post_session.post.return_value = mock_secret_response + + # ssh_backend expects {'data': {'signed_key': value}} + mock_secret_response.json.return_value = { + 'data': {'signed_key': 'test_signed_key'}, + } + + mock_revoke_session = mocker.Mock() + mock_revoke_session.headers = {} + + # requests.Session is called twice: once for secret fetch, once for revoke + session_class_mock.side_effect = [ + mock_get_or_post_session, + mock_revoke_session, + ] + + # Call the backend function to retrieve the secret + # Adding ignore[call-arg] since adding the explicit args to the backend + # functions makes the linter trigger the "too few arguments supplied" + backend_func(**backend_kwargs) # type: ignore[call-arg] + + handle_auth_mock.assert_called_once() + # Verify revocation was attempted + mock_revoke_session.post.assert_called_once() + assert 'auth/token/revoke-self' in mock_revoke_session.post.call_args[0][0] + mock_revoke_session.post.return_value.raise_for_status.assert_called_once() + + @pytest.mark.parametrize( 'plugin', (