diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 2a8e94d83392..dd03400bdc34 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -2932,6 +2932,19 @@ def to_representation(self, data): field['label'] = _(field['label']) if 'help_text' in field: field['help_text'] = _(field['help_text']) + + # Deep copy inputs to avoid modifying the original model data + inputs = value.get('inputs') + if not isinstance(inputs, dict): + inputs = {} + value['inputs'] = copy.deepcopy(inputs) + fields = value['inputs'].get('fields', []) + if not isinstance(fields, list): + fields = [] + + # Filter out internal fields from the API response + value['inputs']['fields'] = [f for f in fields if not f.get('internal')] + return value def filter_field_metadata(self, fields, method): diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index 1456e8bb160b..b31888662350 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -14,6 +14,7 @@ import time from base64 import b64encode from collections import OrderedDict +from jwt import decode as _jwt_decode from urllib3.exceptions import ConnectTimeoutError @@ -58,9 +59,11 @@ from ansible_base.lib.utils.requests import get_remote_hosts from ansible_base.rbac.models import RoleEvaluation from ansible_base.lib.utils.schema import extend_schema_if_available +from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope # AWX -from awx.main.tasks.system import send_notifications, update_inventory_computed_fields +from awx.main.tasks.jobs import retrieve_workload_identity_jwt_with_claims +from awx.main.tasks.system import flag_enabled, send_notifications, update_inventory_computed_fields from awx.main.access import get_user_queryset from awx.api.generics import ( APIView, @@ -163,6 +166,95 @@ def api_exception_handler(exc, context): return exception_handler(exc, context) +def _get_workload_identity_token(job_template: models.JobTemplate, jwt_aud: str) -> str: + claims = { + AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name, + AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id, + AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name, + AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id, + AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name, + AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id, + AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook, + } + # Get a Workload Identity Token + return retrieve_workload_identity_jwt_with_claims( + claims=claims, + audience=jwt_aud, + scope=AutomationControllerJobScope.name, + ) + + +def _validate_and_get_job_template(job_template_id): + """Validate job template ID and return the JobTemplate instance. + + Returns: + tuple: (JobTemplate instance or None, error_message or None) + """ + if job_template_id is None: + return None, _('Job template ID is required.') + + try: + return models.JobTemplate.objects.get(id=int(job_template_id)), None + except ValueError: + return None, _('Job template ID must be an integer.') + except models.JobTemplate.DoesNotExist: + return None, _('Job template with ID %(id)s does not exist.') % {'id': job_template_id} + + +def _decode_jwt_payload_for_display(jwt_token): + """Decode JWT payload for display purposes only (signature not verified). + + This is safe because the JWT was just created by AWX and is only decoded + to show the user what claims are being sent to the external system. + The external system will perform proper signature verification. + """ + # NOSONAR - Signature verification intentionally disabled for display-only decoding + return _jwt_decode(jwt_token, algorithms=["RS256"], options={"verify_signature": False}) + + +def _handle_oidc_credential_test(credential_type_inputs, backend_kwargs, request): + """ + Handle OIDC workload identity token generation for external credential test endpoints. + + Returns: + tuple: (response_body dict, error Response or None) + Modifies backend_kwargs in place to add workload_identity_token. + """ + response_body = {} + + if not flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'): + return response_body, None + + # Check if credential type has an internal workload_identity_token field + fields = credential_type_inputs.get('fields', []) + has_oidc_field = any(field.get('internal') and field.get('id') == 'workload_identity_token' for field in fields) + + if not has_oidc_field: + return response_body, None + + # Validate job template + job_template_id = backend_kwargs.pop('job_template_id', None) + job_template, error_msg = _validate_and_get_job_template(job_template_id) + if error_msg: + response_body['details'] = {'error_message': error_msg} + return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST) + + # Check user access + if not request.user.can_access(models.JobTemplate, 'read', job_template): + raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id}) + + # Generate workload identity token + try: + jwt_token = _get_workload_identity_token(job_template, backend_kwargs.pop('jwt_aud', None)) + backend_kwargs['workload_identity_token'] = jwt_token + response_body['details'] = {'sent_jwt_payload': _decode_jwt_payload_for_display(jwt_token)} + except RuntimeError as exc: + response_body['details'] = {'error_message': str(exc)} + return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST) + + return response_body, None + + class DashboardView(APIView): deprecated = True @@ -1622,23 +1714,32 @@ def post(self, request, *args, **kwargs): if value != '$encrypted$': backend_kwargs[field_name] = value backend_kwargs.update(request.data.get('metadata', {})) + + # Add extra test functionality for OIDC-enabled credential types + response_body, error_response = _handle_oidc_credential_test(obj.credential_type.inputs, backend_kwargs, request) + if error_response: + return error_response + try: with set_environ(**settings.AWX_TASK_ENV): obj.credential_type.plugin.backend(**backend_kwargs) - return Response({}, status=status.HTTP_202_ACCEPTED) + return Response(response_body, status=status.HTTP_202_ACCEPTED) except requests.exceptions.HTTPError: message = """Test operation is not supported for credential type {}. This endpoint only supports credentials that connect to external secret management systems such as CyberArk, HashiCorp Vault, or cloud-based secret managers.""".format(obj.credential_type.kind) - return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST) + response_body.setdefault('details', {})['error_message'] = message + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) except Exception as exc: - message = exc.__class__.__name__ + # Use the exception message if available, otherwise fall back to the class name + message = str(exc) if str(exc) else exc.__class__.__name__ exc_args = getattr(exc, 'args', []) for a in exc_args: if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): message = str(a.reason) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) + response_body.setdefault('details', {})['error_message'] = message + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView): @@ -1685,19 +1786,28 @@ def post(self, request, *args, **kwargs): obj = self.get_object() backend_kwargs = request.data.get('inputs', {}) backend_kwargs.update(request.data.get('metadata', {})) + + # Add extra test functionality for OIDC-enabled credential types + response_body, error_response = _handle_oidc_credential_test(obj.inputs, backend_kwargs, request) + if error_response: + return error_response + try: obj.plugin.backend(**backend_kwargs) - return Response({}, status=status.HTTP_202_ACCEPTED) + return Response(response_body, status=status.HTTP_202_ACCEPTED) except requests.exceptions.HTTPError as exc: message = 'HTTP {}'.format(exc.response.status_code) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) + response_body.setdefault('details', {})['error_message'] = message + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) except Exception as exc: - message = exc.__class__.__name__ + # Use the exception message if available, otherwise fall back to the class name + message = str(exc) if str(exc) else exc.__class__.__name__ args_exc = getattr(exc, 'args', []) for a in args_exc: if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): message = str(a.reason) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) + response_body.setdefault('details', {})['error_message'] = message + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) class HostRelatedSearchMixin(object): diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index e6f6a893f1f4..dc93a4dd8347 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -94,7 +94,7 @@ # Workload Identity from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope -from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client +from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims logger = logging.getLogger('awx.main.tasks.jobs') @@ -168,14 +168,12 @@ def retrieve_workload_identity_jwt( Raises: RuntimeError: if the workload identity client is not configured. """ - client = get_workload_identity_client() - if client is None: - raise RuntimeError("Workload identity client is not configured") - claims = populate_claims_for_workload(unified_job) - kwargs = {"claims": claims, "scope": scope, "audience": audience} - if workload_ttl_seconds: - kwargs["workload_ttl_seconds"] = workload_ttl_seconds - return client.request_workload_jwt(**kwargs).jwt + return retrieve_workload_identity_jwt_with_claims( + populate_claims_for_workload(unified_job), + audience, + scope, + workload_ttl_seconds, + ) def with_path_cleanup(f): diff --git a/awx/main/tests/functional/api/test_credential_type.py b/awx/main/tests/functional/api/test_credential_type.py index ed0f1e9f28fb..29f1875f438e 100644 --- a/awx/main/tests/functional/api/test_credential_type.py +++ b/awx/main/tests/functional/api/test_credential_type.py @@ -159,7 +159,8 @@ def test_create_as_admin(get, post, admin): response = get(reverse('api:credential_type_list'), admin) assert response.data['count'] == 1 assert response.data['results'][0]['name'] == 'Custom Credential Type' - assert response.data['results'][0]['inputs'] == {} + # Serializer normalizes empty inputs to {'fields': []} + assert response.data['results'][0]['inputs'] == {'fields': []} assert response.data['results'][0]['injectors'] == {} assert response.data['results'][0]['managed'] is False @@ -474,3 +475,96 @@ def test_credential_type_rbac_external_test(post, alice, admin, credentialtype_e data = {'inputs': {}, 'metadata': {}} assert post(url, data, admin).status_code == 202 assert post(url, data, alice).status_code == 403 + + +# --- Tests for internal field filtering with None/invalid inputs --- + + +@pytest.mark.django_db +def test_credential_type_with_none_inputs(get, admin): + """Test that credential type with empty inputs dict works correctly.""" + # Create a credential type with empty dict + ct = CredentialType.objects.create( + kind='cloud', + name='Test Type', + managed=False, + inputs={}, # Empty dict, not None (DB has NOT NULL constraint) + ) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + response = get(url, admin) + assert response.status_code == 200 + # Should have normalized inputs to empty dict + assert 'inputs' in response.data + assert isinstance(response.data['inputs'], dict) + assert response.data['inputs']['fields'] == [] + + +@pytest.mark.django_db +def test_credential_type_with_invalid_inputs_type(get, admin): + """Test that credential type with non-dict inputs doesn't cause errors.""" + # Create a credential type with invalid inputs type + ct = CredentialType.objects.create(kind='cloud', name='Test Type', managed=False, inputs={'fields': 'not-a-list'}) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + response = get(url, admin) + assert response.status_code == 200 + # Should gracefully handle invalid fields type + assert 'inputs' in response.data + assert response.data['inputs']['fields'] == [] + + +@pytest.mark.django_db +def test_credential_type_filters_internal_fields(get, admin): + """Test that internal fields are filtered from API responses.""" + ct = CredentialType.objects.create( + kind='cloud', + name='Test OIDC Type', + managed=False, + inputs={ + 'fields': [ + {'id': 'url', 'label': 'URL', 'type': 'string'}, + {'id': 'token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True}, + {'id': 'public_field', 'label': 'Public', 'type': 'string'}, + ] + }, + ) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + response = get(url, admin) + assert response.status_code == 200 + + field_ids = [f['id'] for f in response.data['inputs']['fields']] + # Internal field should be filtered out + assert 'token' not in field_ids + assert 'url' in field_ids + assert 'public_field' in field_ids + + +@pytest.mark.django_db +def test_credential_type_list_filters_internal_fields(get, admin): + """Test that internal fields are filtered in list view.""" + CredentialType.objects.create( + kind='cloud', + name='Test OIDC Type', + managed=False, + inputs={ + 'fields': [ + {'id': 'url', 'label': 'URL', 'type': 'string'}, + {'id': 'workload_identity_token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True}, + ] + }, + ) + + url = reverse('api:credential_type_list') + response = get(url, admin) + assert response.status_code == 200 + + # Find our credential type in the results + test_ct = next((ct for ct in response.data['results'] if ct['name'] == 'Test OIDC Type'), None) + assert test_ct is not None + + field_ids = [f['id'] for f in test_ct['inputs']['fields']] + # Internal field should be filtered out + assert 'workload_identity_token' not in field_ids + assert 'url' in field_ids diff --git a/awx/main/tests/functional/api/test_oidc_credential_test.py b/awx/main/tests/functional/api/test_oidc_credential_test.py new file mode 100644 index 000000000000..d7147a57fbb6 --- /dev/null +++ b/awx/main/tests/functional/api/test_oidc_credential_test.py @@ -0,0 +1,259 @@ +""" +Tests for OIDC workload identity credential test endpoints. + +Tests the /api/v2/credentials//test/ and /api/v2/credential_types//test/ +endpoints when used with OIDC-enabled credential types. +""" + +import pytest +from unittest import mock + +from django.test import override_settings + +from awx.main.models import Credential, CredentialType, JobTemplate +from awx.api.versioning import reverse + + +@pytest.fixture +def job_template(organization, project): + """Job template with organization and project for OIDC JWT generation.""" + return JobTemplate.objects.create(name='test-jt', organization=organization, project=project, playbook='helloworld.yml') + + +@pytest.fixture +def oidc_credentialtype(): + """Create a credential type with workload_identity_token internal field.""" + oidc_type_inputs = { + 'fields': [ + {'id': 'url', 'label': 'Vault URL', 'type': 'string', 'help_text': 'The Vault server URL.'}, + {'id': 'auth_path', 'label': 'Auth Path', 'type': 'string', 'help_text': 'JWT auth mount path.'}, + {'id': 'role_id', 'label': 'Role ID', 'type': 'string', 'help_text': 'Vault role.'}, + {'id': 'jwt_aud', 'label': 'JWT Audience', 'type': 'string', 'help_text': 'Expected audience.'}, + {'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'secret': True, 'internal': True}, + ], + 'metadata': [ + {'id': 'secret_path', 'label': 'Secret Path', 'type': 'string'}, + {'id': 'job_template_id', 'label': 'Job Template ID', 'type': 'string'}, + ], + 'required': ['url', 'auth_path', 'role_id'], + } + + class MockPlugin(object): + def backend(self, **kwargs): + # Simulate successful backend call + return 'secret' + + with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin: + mock_plugin.return_value = MockPlugin() + oidc_type = CredentialType(kind='external', managed=True, namespace='hashivault-kv-oidc', name='HashiCorp Vault KV (OIDC)', inputs=oidc_type_inputs) + oidc_type.save() + yield oidc_type + + +@pytest.fixture +def oidc_credential(oidc_credentialtype): + """Create a credential using the OIDC credential type.""" + return Credential.objects.create( + credential_type=oidc_credentialtype, + name='oidc-vault-cred', + inputs={'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + ) + + +@pytest.fixture +def mock_oidc_backend(): + """Fixture that mocks OIDC JWT generation and credential backend.""" + with mock.patch('awx.api.views.retrieve_workload_identity_jwt_with_claims') as mock_jwt, mock.patch('awx.api.views._jwt_decode') as mock_decode, mock.patch( + 'awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock + ) as mock_plugin: + + # Set default return values + mock_jwt.return_value = 'fake.jwt.token' + mock_decode.return_value = {'iss': 'http://gateway/o', 'aud': 'vault'} + + # Create mock backend + mock_backend = mock.MagicMock() + mock_backend.backend.return_value = 'secret' + mock_plugin.return_value = mock_backend + + # Yield all mocks for test customization + yield { + 'jwt': mock_jwt, + 'decode': mock_decode, + 'plugin': mock_plugin, + 'backend': mock_backend, + } + + +# --- Tests for CredentialExternalTest endpoint --- + + +@pytest.mark.django_db +@override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False) +def test_credential_test_without_oidc_feature_flag(post, admin, oidc_credential): + """Test that credential test works without OIDC feature flag enabled.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': '1'}} + + with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin: + mock_backend = mock.MagicMock() + mock_backend.backend.return_value = 'secret' + mock_plugin.return_value = mock_backend + + response = post(url, data, admin) + assert response.status_code == 202 + # Should not contain JWT payload when feature flag is disabled + assert 'details' not in response.data or 'sent_jwt_payload' not in response.data.get('details', {}) + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +@pytest.mark.parametrize( + 'job_template_id, expected_error', + [ + (None, 'Job template ID is required'), + ('not-an-integer', 'must be an integer'), + ('99999', 'does not exist'), + ], + ids=['missing_job_template_id', 'invalid_job_template_id_type', 'nonexistent_job_template_id'], +) +def test_credential_test_job_template_validation(mock_flag, post, admin, oidc_credential, job_template_id, expected_error): + """Test that invalid job_template_id values return 400 with appropriate error messages.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret'}} + if job_template_id is not None: + data['metadata']['job_template_id'] = job_template_id + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert expected_error in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_no_access_to_job_template(mock_flag, post, alice, oidc_credential, job_template): + """Test that user without access to job template gets 403.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Give alice use permission on credential but not on job template + oidc_credential.use_role.members.add(alice) + + response = post(url, data, alice) + assert response.status_code == 403 + assert 'You do not have access to job template' in str(response.data) + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that successful test returns JWT payload in response.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Customize mock for this test + mock_oidc_backend['decode'].return_value = { + 'iss': 'http://gateway/o', + 'sub': 'system:serviceaccount:default:awx-operator', + 'aud': 'vault', + 'job_template_id': job_template.id, + } + + response = post(url, data, admin) + assert response.status_code == 202 + assert 'details' in response.data + assert 'sent_jwt_payload' in response.data['details'] + assert response.data['details']['sent_jwt_payload']['job_template_id'] == job_template.id + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_backend_failure_returns_jwt_and_error(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that backend failure still returns JWT payload along with error message.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Make backend fail + mock_oidc_backend['backend'].backend.side_effect = RuntimeError('Connection failed') + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + # Both JWT payload and error message should be present + assert 'sent_jwt_payload' in response.data['details'] + assert 'error_message' in response.data['details'] + assert 'Connection failed' in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_jwt_generation_failure(mock_flag, post, admin, oidc_credential, job_template): + """Test that JWT generation failure returns error without JWT payload.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + with mock.patch('awx.api.views._get_workload_identity_token') as mock_jwt: + mock_jwt.side_effect = RuntimeError('Failed to generate JWT') + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert 'Failed to generate JWT' in response.data['details']['error_message'] + # No JWT payload when generation fails + assert 'sent_jwt_payload' not in response.data['details'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_job_template_id_not_passed_to_backend(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that job_template_id and jwt_aud are removed from backend_kwargs.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + response = post(url, data, admin) + assert response.status_code == 202 + + # Check that backend was called without job_template_id or jwt_aud + call_kwargs = mock_oidc_backend['backend'].backend.call_args[1] + assert 'job_template_id' not in call_kwargs + assert 'jwt_aud' not in call_kwargs + assert 'workload_identity_token' in call_kwargs + + +# --- Tests for CredentialTypeExternalTest endpoint --- + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_type_test_missing_job_template_id(mock_flag, post, admin, oidc_credentialtype): + """Test that missing job_template_id returns 400 for credential type test endpoint.""" + url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk}) + data = { + 'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + 'metadata': {'secret_path': 'test/secret'}, + } + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert 'Job template ID is required' in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_type_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend): + """Test that successful credential type test returns JWT payload.""" + url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk}) + data = { + 'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + 'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}, + } + + response = post(url, data, admin) + assert response.status_code == 202 + assert 'details' in response.data + assert 'sent_jwt_payload' in response.data['details'] diff --git a/awx/main/tests/unit/tasks/test_jobs.py b/awx/main/tests/unit/tasks/test_jobs.py index bcc6f4d0fd52..e4df52b63f16 100644 --- a/awx/main/tests/unit/tasks/test_jobs.py +++ b/awx/main/tests/unit/tasks/test_jobs.py @@ -473,7 +473,7 @@ def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims): assert claims == expected_claims -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client): """retrieve_workload_identity_jwt returns the JWT string from the client.""" mock_client = mock.MagicMock() @@ -502,7 +502,7 @@ def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client) assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job' -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client): """retrieve_workload_identity_jwt passes audience and scope to the client.""" mock_client = mock.MagicMock() @@ -518,7 +518,7 @@ def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_clien mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience) -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client): """retrieve_workload_identity_jwt passes workload_ttl_seconds when provided.""" mock_client = mock.Mock() @@ -542,7 +542,7 @@ def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client): ) -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client): """retrieve_workload_identity_jwt raises RuntimeError when client is None.""" mock_get_client.return_value = None diff --git a/awx/main/utils/workload_identity.py b/awx/main/utils/workload_identity.py index e69de29bb2d1..50582e224597 100644 --- a/awx/main/utils/workload_identity.py +++ b/awx/main/utils/workload_identity.py @@ -0,0 +1,22 @@ +from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client + +__all__ = ['retrieve_workload_identity_jwt_with_claims'] + + +def retrieve_workload_identity_jwt_with_claims( + claims: dict, + audience: str, + scope: str, + workload_ttl_seconds: int | None = None, +) -> str: + """Retrieve JWT token from workload claims. + Raises: + RuntimeError: if the workload identity client is not configured. + """ + client = get_workload_identity_client() + if client is None: + raise RuntimeError("Workload identity client is not configured") + kwargs = {"claims": claims, "scope": scope, "audience": audience} + if workload_ttl_seconds: + kwargs["workload_ttl_seconds"] = workload_ttl_seconds + return client.request_workload_jwt(**kwargs).jwt