Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions awx/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2932,6 +2932,19 @@ def to_representation(self, data):
field['label'] = _(field['label'])
if 'help_text' in field:
field['help_text'] = _(field['help_text'])

# Deep copy inputs to avoid modifying the original model data
inputs = value.get('inputs')
if not isinstance(inputs, dict):
inputs = {}
value['inputs'] = copy.deepcopy(inputs)
fields = value['inputs'].get('fields', [])
if not isinstance(fields, list):
fields = []

# Filter out internal fields from the API response
value['inputs']['fields'] = [f for f in fields if not f.get('internal')]

return value

def filter_field_metadata(self, fields, method):
Expand Down
99 changes: 92 additions & 7 deletions awx/api/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import time
from base64 import b64encode
from collections import OrderedDict
from jwt import decode as _jwt_decode

from urllib3.exceptions import ConnectTimeoutError

Expand Down Expand Up @@ -58,9 +59,11 @@
from ansible_base.lib.utils.requests import get_remote_hosts
from ansible_base.rbac.models import RoleEvaluation
from ansible_base.lib.utils.schema import extend_schema_if_available
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope

# AWX
from awx.main.tasks.system import send_notifications, update_inventory_computed_fields
from awx.main.tasks.jobs import retrieve_workload_identity_jwt_with_claims
from awx.main.tasks.system import flag_enabled, send_notifications, update_inventory_computed_fields
from awx.main.access import get_user_queryset
from awx.api.generics import (
APIView,
Expand Down Expand Up @@ -163,6 +166,72 @@ def api_exception_handler(exc, context):
return exception_handler(exc, context)


def _get_workload_identity_token(job_template: models.JobTemplate, jwt_aud: str) -> str:
claims = {
AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name,
AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id,
AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name,
AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name,
AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id,
AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook,
}
# Get a Workload Identity Token
return retrieve_workload_identity_jwt_with_claims(
claims=claims,
audience=jwt_aud,
scope=AutomationControllerJobScope.name,
)


def _handle_oidc_credential_test(credential_type_inputs, backend_kwargs, request):
"""
Handle OIDC workload identity token generation for external credential test endpoints.

Returns:
tuple: (response_body dict, error Response or None)
Modifies backend_kwargs in place to add workload_identity_token.
"""
response_body = {}

if not flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'):
return response_body, None

# Get a Workload Identity Token if credential contains an internal 'workload_identity_token' field
fields = credential_type_inputs.get('fields', [])
for field in fields:
if field.get('internal') and field.get('id') == 'workload_identity_token':
# Make sure that the requesting user has access to the job template
job_template_id = backend_kwargs.pop('job_template_id', None)
if job_template_id is None:
response_body['details'] = {'error_message': _('Job template ID is required.')}
return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST)

try:
job_template = models.JobTemplate.objects.get(id=int(job_template_id))
except ValueError:
response_body['details'] = {'error_message': _('Job template ID must be an integer.')}
return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST)
except models.JobTemplate.DoesNotExist:
response_body['details'] = {'error_message': _('Job template with ID %(id)s does not exist.') % {'id': job_template_id}}
return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST)

if not request.user.can_access(models.JobTemplate, 'read', job_template):
raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id})

try:
backend_kwargs['workload_identity_token'] = _get_workload_identity_token(job_template, backend_kwargs.pop('jwt_aud', None))
except RuntimeError as exc:
response_body['details'] = {'error_message': str(exc)}
return response_body, Response(response_body, status=status.HTTP_400_BAD_REQUEST)

response_body['details'] = {
'sent_jwt_payload': _jwt_decode(backend_kwargs['workload_identity_token'], algorithms=["RS256"], options={"verify_signature": False}),
}

return response_body, None


class DashboardView(APIView):
deprecated = True

Expand Down Expand Up @@ -1622,23 +1691,31 @@ def post(self, request, *args, **kwargs):
if value != '$encrypted$':
backend_kwargs[field_name] = value
backend_kwargs.update(request.data.get('metadata', {}))

# Add extra test functionality for OIDC-enabled credential types
response_body, error_response = _handle_oidc_credential_test(obj.credential_type.inputs, backend_kwargs, request)
if error_response:
return error_response

try:
with set_environ(**settings.AWX_TASK_ENV):
obj.credential_type.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED)
return Response(response_body, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError:
message = """Test operation is not supported for credential type {}.
This endpoint only supports credentials that connect to
external secret management systems such as CyberArk, HashiCorp
Vault, or cloud-based secret managers.""".format(obj.credential_type.kind)
return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST)
response_body.setdefault('details', {})['error_message'] = message
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc:
message = exc.__class__.__name__
exc_args = getattr(exc, 'args', [])
for a in exc_args:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
response_body.setdefault('details', {})['error_message'] = message
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)


class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -1685,19 +1762,27 @@ def post(self, request, *args, **kwargs):
obj = self.get_object()
backend_kwargs = request.data.get('inputs', {})
backend_kwargs.update(request.data.get('metadata', {}))

# Add extra test functionality for OIDC-enabled credential types
response_body, error_response = _handle_oidc_credential_test(obj.inputs, backend_kwargs, request)
if error_response:
return error_response

Comment on lines +1765 to +1770
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is duplicated in both view implementations. Can we refactor this so we don't have two copies of the same thing?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored into helper function.

try:
obj.plugin.backend(**backend_kwargs)
return Response({}, status=status.HTTP_202_ACCEPTED)
return Response(response_body, status=status.HTTP_202_ACCEPTED)
except requests.exceptions.HTTPError as exc:
message = 'HTTP {}'.format(exc.response.status_code)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
response_body.setdefault('details', {})['error_message'] = message
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)
except Exception as exc:
message = exc.__class__.__name__
args_exc = getattr(exc, 'args', [])
for a in args_exc:
if isinstance(getattr(a, 'reason', None), ConnectTimeoutError):
message = str(a.reason)
return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST)
response_body.setdefault('details', {})['error_message'] = message
return Response(response_body, status=status.HTTP_400_BAD_REQUEST)


class HostRelatedSearchMixin(object):
Expand Down
16 changes: 7 additions & 9 deletions awx/main/tasks/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@

# Workload Identity
from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope
from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client
from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims

logger = logging.getLogger('awx.main.tasks.jobs')

Expand Down Expand Up @@ -168,14 +168,12 @@ def retrieve_workload_identity_jwt(
Raises:
RuntimeError: if the workload identity client is not configured.
"""
client = get_workload_identity_client()
if client is None:
raise RuntimeError("Workload identity client is not configured")
claims = populate_claims_for_workload(unified_job)
kwargs = {"claims": claims, "scope": scope, "audience": audience}
if workload_ttl_seconds:
kwargs["workload_ttl_seconds"] = workload_ttl_seconds
return client.request_workload_jwt(**kwargs).jwt
return retrieve_workload_identity_jwt_with_claims(
populate_claims_for_workload(unified_job),
audience,
scope,
workload_ttl_seconds,
)


def with_path_cleanup(f):
Expand Down
23 changes: 23 additions & 0 deletions awx/main/utils/workload_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client


__all__ = ['retrieve_workload_identity_jwt_with_claims']


def retrieve_workload_identity_jwt_with_claims(
claims: dict,
audience: str,
scope: str,
workload_ttl_seconds: int | None = None,
) -> str:
"""Retrieve JWT token from workload claims.
Raises:
RuntimeError: if the workload identity client is not configured.
"""
client = get_workload_identity_client()
if client is None:
raise RuntimeError("Workload identity client is not configured")
kwargs = {"claims": claims, "scope": scope, "audience": audience}
if workload_ttl_seconds:
kwargs["workload_ttl_seconds"] = workload_ttl_seconds
return client.request_workload_jwt(**kwargs).jwt
Loading