Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions awx/main/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def _load_credential_types_feature(self):

@bypass_in_test
def load_credential_types_feature(self):
from awx.main.models.credential import load_credentials

load_credentials()
return self._load_credential_types_feature()

def load_inventory_plugins(self):
Expand Down
83 changes: 49 additions & 34 deletions awx/main/models/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# All Rights Reserved.
from contextlib import nullcontext
import functools

import inspect
import logging
from importlib.metadata import entry_points
Expand Down Expand Up @@ -47,6 +48,8 @@
)
from awx.main.models import Team, Organization
from awx.main.utils import encrypt_field
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name


# DAB
from ansible_base.resource_registry.tasks.sync import get_resource_server_client
Expand All @@ -56,7 +59,6 @@
__all__ = ['Credential', 'CredentialType', 'CredentialInputSource', 'build_safe_env']

logger = logging.getLogger('awx.main.models.credential')
credential_plugins = {entry_point.name: entry_point.load() for entry_point in entry_points(group='awx_plugins.credentials')}

HIDDEN_PASSWORD = '**********'

Expand Down Expand Up @@ -462,8 +464,7 @@ def askable_fields(self):
def plugin(self):
if self.kind != 'external':
raise AttributeError('plugin')
[plugin] = [plugin for ns, plugin in credential_plugins.items() if ns == self.namespace]
return plugin
return ManagedCredentialType.registry.get(self.namespace, None)

def default_for_field(self, field_id):
for field in self.inputs.get('fields', []):
Expand All @@ -474,7 +475,7 @@ def default_for_field(self, field_id):

@classproperty
def defaults(cls):
return dict((k, functools.partial(v.create)) for k, v in ManagedCredentialType.registry.items())
return dict((k, functools.partial(CredentialTypeHelper.create, v)) for k, v in ManagedCredentialType.registry.items())

@classmethod
def _get_credential_type_class(cls, apps: Apps = None, app_config: AppConfig = None):
Expand Down Expand Up @@ -509,7 +510,7 @@ def _setup_tower_managed_defaults(cls, apps: Apps = None, app_config: AppConfig
existing.save()
continue
logger.debug(_("adding %s credential type" % default.name))
params = default.get_creation_params()
params = CredentialTypeHelper.get_creation_params(default)
if 'managed' not in [f.name for f in ct_class._meta.get_fields()]:
params['managed_by_tower'] = params.pop('managed')
params['created'] = params['modified'] = now() # CreatedModifiedModel service
Expand Down Expand Up @@ -543,46 +544,37 @@ def setup_tower_managed_defaults(cls, apps: Apps = None, app_config: AppConfig =
@classmethod
def load_plugin(cls, ns, plugin):
# TODO: User "side-loaded" credential custom_injectors isn't supported
ManagedCredentialType(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs)
ManagedCredentialType.registry[ns] = ManagedCredentialType(namespace=ns, name=plugin.name, kind='external', inputs=plugin.inputs, injectors={})

def inject_credential(self, credential, env, safe_env, args, private_data_dir):
from awx_plugins.interfaces._temporary_private_inject_api import inject_credential

inject_credential(self, credential, env, safe_env, args, private_data_dir)


class ManagedCredentialType(SimpleNamespace):
registry = {}

def __init__(self, namespace, **kwargs):
for k in ('inputs', 'injectors'):
if k not in kwargs:
kwargs[k] = {}
super(ManagedCredentialType, self).__init__(namespace=namespace, **kwargs)
if namespace in ManagedCredentialType.registry:
raise ValueError(
'a ManagedCredentialType with namespace={} is already defined in {}'.format(
namespace, inspect.getsourcefile(ManagedCredentialType.registry[namespace].__class__)
)
)
ManagedCredentialType.registry[namespace] = self

def get_creation_params(self):
class CredentialTypeHelper:
@classmethod
def get_creation_params(cls, cred_type):
return dict(
namespace=self.namespace,
kind=self.kind,
name=self.name,
namespace=cred_type.namespace,
kind=cred_type.kind,
name=cred_type.name,
managed=True,
inputs=self.inputs,
injectors=self.injectors,
inputs=cred_type.inputs,
injectors=cred_type.injectors,
)

def create(self):
res = CredentialType(**self.get_creation_params())
res.custom_injectors = getattr(self, 'custom_injectors', None)
@classmethod
def create(cls, cred_type):
res = CredentialType(**CredentialTypeHelper.get_creation_params(cred_type))
res.custom_injectors = getattr(cred_type, "custom_injectors", None)
return res


class ManagedCredentialType(SimpleNamespace):
registry = {}


class CredentialInputSource(PrimordialModel):
class Meta:
app_label = 'main'
Expand Down Expand Up @@ -647,7 +639,30 @@ def get_absolute_url(self, request=None):
return reverse(view_name, kwargs={'pk': self.pk}, request=request)


from awx_plugins.credentials.plugins import * # noqa
def load_credentials():

awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')}
supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')}
plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points}

for ns, ep in plugin_entry_points.items():
cred_plugin = ep.load()
if not hasattr(cred_plugin, 'inputs'):
setattr(cred_plugin, 'inputs', {})
if not hasattr(cred_plugin, 'injectors'):
setattr(cred_plugin, 'injectors', {})
if ns in ManagedCredentialType.registry:
raise ValueError(
'a ManagedCredentialType with namespace={} is already defined in {}'.format(
ns, inspect.getsourcefile(ManagedCredentialType.registry[ns].__class__)
)
)
ManagedCredentialType.registry[ns] = cred_plugin

credential_plugins = {ep.name: ep for ep in entry_points(group='awx_plugins.credentials')}
if detect_server_product_name() == 'AWX':
credential_plugins = {}

for ns, plugin in credential_plugins.items():
CredentialType.load_plugin(ns, plugin)
for ns, ep in credential_plugins.items():
plugin = ep.load()
CredentialType.load_plugin(ns, plugin)
9 changes: 9 additions & 0 deletions awx/main/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,12 @@ def me_inst():
me_mock = mock.MagicMock(return_value=inst)
with mock.patch.object(Instance.objects, 'me', me_mock):
yield inst


@pytest.fixture(scope="session", autouse=True)
def load_all_credentials():
with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'):
from awx.main.models.credential import load_credentials

load_credentials()
yield
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,24 @@ def cleanup_cloudforms():
assert 'cloudforms' not in CredentialType.defaults


@pytest.mark.django_db
def test_cloudforms_inventory_removal(request, inventory):
request.addfinalizer(cleanup_cloudforms)
ManagedCredentialType(
@pytest.fixture
def cloudforms_mct():
ManagedCredentialType.registry['cloudforms'] = ManagedCredentialType(
name='Red Hat CloudForms',
namespace='cloudforms',
kind='cloud',
managed=True,
inputs={},
injectors={},
)
yield
ManagedCredentialType.registry.pop('cloudforms', None)


@pytest.mark.django_db
def test_cloudforms_inventory_removal(request, inventory, cloudforms_mct):
request.addfinalizer(cleanup_cloudforms)

CredentialType.defaults['cloudforms']().save()
cloudforms = CredentialType.objects.get(namespace='cloudforms')
Credential.objects.create(
Expand Down
1 change: 1 addition & 0 deletions awx/main/tests/live/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# These tests are invoked from the awx/main/tests/live/ subfolder
# so any fixtures from higher-up conftest files must be explicitly included
from awx.main.tests.functional.conftest import * # noqa
from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import

from awx.main.models import Organization

Expand Down
1 change: 1 addition & 0 deletions awx_collection/test/awx/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ansible.module_utils.six import raise_from

from ansible_base.rbac.models import RoleDefinition, DABPermission
from awx.main.tests.conftest import load_all_credentials # noqa: F401; pylint: disable=unused-import
from awx.main.tests.functional.conftest import _request
from awx.main.tests.functional.conftest import credentialtype_scm, credentialtype_ssh # noqa: F401; pylint: disable=unused-import
from awx.main.models import (
Expand Down
Loading