From 2d26b5a3a033099f07e5c31559d6c69bd9dbfc8a Mon Sep 17 00:00:00 2001 From: Aleksandr Cupacenko Date: Thu, 19 Mar 2020 10:21:21 +0100 Subject: [PATCH 1/2] move hosts management to AnsibleHosts class to enable large inventories (10K+) file linear system based processing of parsed hosts data to reduce memory usage. --- src/ansiblecmdb/ansible.py | 83 ++--------- src/ansiblecmdb/ansiblehosts.py | 242 ++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+), 72 deletions(-) create mode 100644 src/ansiblecmdb/ansiblehosts.py diff --git a/src/ansiblecmdb/ansible.py b/src/ansiblecmdb/ansible.py index f1c22d2..d2b89b6 100644 --- a/src/ansiblecmdb/ansible.py +++ b/src/ansiblecmdb/ansible.py @@ -7,6 +7,7 @@ from . import ihateyaml import ansiblecmdb.util as util import ansiblecmdb.parser as parser +from ansiblecmdb.ansiblehosts import AnsibleHosts def strip_exts(s, exts): @@ -43,9 +44,9 @@ def __init__(self, fact_dirs, inventory_paths=None, fact_cache=False, else: self.inventory_paths = inventory_paths self.fact_cache = fact_cache # fact dirs are fact-caches - self.limit = self._parse_limit(limit) self.debug = debug - self.hosts = {} + self.hosts = AnsibleHosts() + self.hosts.setLimit(limit) self.log = logging.getLogger(__name__) # Process facts gathered by Ansible's setup module of fact caching. @@ -64,28 +65,6 @@ def __init__(self, fact_dirs, inventory_paths=None, fact_cache=False, for inventory_path in self.inventory_paths: self._parse_groupvar_dir(inventory_path) - def _parse_limit(self, limit): - """ - Parse a host / group limit in the form of a string (e.g. - 'all:!cust.acme') into a dict of things to be included and things to be - excluded. 
- """ - if limit is None: - return None - - limit_parsed = { - "include": [], - "exclude": [] - } - elems = limit.split(":") - for elem in elems: - if elem.startswith('!'): - limit_parsed['exclude'].append(elem[1:]) - else: - limit_parsed['include'].append(elem) - - return limit_parsed - def _handle_inventory(self, inventory_path): """ Scan inventory. As Ansible is a big mess without any kind of @@ -326,65 +305,25 @@ def _parse_dyn_inventory(self, script): def update_host(self, hostname, key_values, overwrite=True): """ - Update a hosts information. This is called by various collectors such - as the ansible setup module output and the hosts parser to add - informatio to a host. It does some deep inspection to make sure nested - information can be updated. + Let hosts object update itself """ - default_empty_host = { - 'name': hostname, - 'hostvars': {}, - } - host_info = self.hosts.get(hostname, default_empty_host) - util.deepupdate(host_info, key_values, overwrite=overwrite) - self.hosts[hostname] = host_info + self.hosts.update_host(hostname, key_values, overwrite) def hosts_all(self): """ - Return a list of all hostnames. + Let the hosts object return a list of all hostnames """ - return [hostname for hostname, hostinfo in self.hosts.items()] + return list(self.hosts.hosts_all()) def hosts_in_group(self, groupname): """ - Return a list of hostnames that are in a group. + Let the hosts object return a list of hostnames that are in a group. """ - result = [] - for hostname, hostinfo in self.hosts.items(): - if groupname == 'all': - result.append(hostname) - elif 'groups' in hostinfo: - if groupname in hostinfo['groups']: - result.append(hostname) - else: - hostinfo['groups'] = [groupname] - return result + return list(self.hosts.hosts_in_group(groupname)) def get_hosts(self): """ Return a list of parsed hosts info, with the limit applied if required. 
+ Limits are applied at runtime on the host object level so just return the object """ - limited_hosts = {} - if self.limit is not None: - # Find hosts and groups of hosts to include - for include in self.limit['include']: - # Include whole group - for hostname in self.hosts_in_group(include): - limited_hosts[hostname] = self.hosts[hostname] - # Include individual host - if include in self.hosts: - limited_hosts[include] = self.hosts[include] - # Find hosts and groups of hosts to exclude - for exclude in self.limit["exclude"]: - # Exclude whole group - for hostname in self.hosts_in_group(exclude): - if hostname in limited_hosts: - limited_hosts.pop(hostname) - # Exclude individual host - if exclude in limited_hosts: - limited_hosts.pop(exclude) - - return limited_hosts - else: - # Return all hosts - return self.hosts + return self.hosts diff --git a/src/ansiblecmdb/ansiblehosts.py b/src/ansiblecmdb/ansiblehosts.py new file mode 100644 index 0000000..f656d1b --- /dev/null +++ b/src/ansiblecmdb/ansiblehosts.py @@ -0,0 +1,242 @@ +import os +import tempfile +import shutil +import logging +import base64 +import codecs +import pickle +import ansiblecmdb.util as util + +class AnsibleHosts(object): + """ + Container class to store and retrieve ansible hosts as objects. + Implements dictionary functionality by overriding its default methods + Serves as a replacement for single dictionary of hosts and their attributes + so that hosts data is linearly fetched from disk rather than keeping the whole + dictionary in memory. Enables processing of 10K+ hosts inventories. 
+ """ + def __init__(self): + self.log = logging.getLogger(__name__) + # set include/exclude limit + self.limit = None + # create temporary directory + self._tmp_dir = tempfile.mkdtemp() + self.log.debug("Created temporary directory {0}".format(self._tmp_dir)) + + def __del__(self): + # delete temporary directory + self.log.debug("Removing temporary directory {0}".format(self._tmp_dir)) + shutil.rmtree(self._tmp_dir, ignore_errors=True) + + def __getitem__(self, key): + """ + Enable dictionary-style access to the items, ex. hosts['example.com'] + """ + filename = os.path.join(self._tmp_dir, self._hostname2filename(key)) + if os.path.isfile(filename): + host_data = self._load_data_from_file(filename) + if self._host_matches_limits(host_data): + return host_data + else: + raise KeyError('{0} host does not match limits.'.format(key)) + else: + raise KeyError('{0} does not exist.'.format(key)) + + def get(self, key, default=None): + """ + Get single host data by hostname; returns a new empty dict if no default is given + """ + filename = os.path.join(self._tmp_dir, self._hostname2filename(key)) + if os.path.isfile(filename): + host_data = self._load_data_from_file(filename) + if self._host_matches_limits(host_data): + return host_data + return {} if default is None else default + + def __contains__(self, key): + """ + Enable dictionary-style check for key existence, ex. 'example.com' in hosts + """ + filename = os.path.join(self._tmp_dir, self._hostname2filename(key)) + return os.path.isfile(filename) + + def __setitem__(self, key, value): + """ + Enable dictionary-style value update, ex. 
hosts['example.com'] = newValue + """ + self.update_host(key, value) + + def __iter__(self): + """ + Make the object iterable + """ + return iter(self.items()) + + def __len__(self): + """ + Implement len method for the dictionary object + """ + i = 0 + for name in os.listdir(self._tmp_dir): + if os.path.isfile(os.path.join(self._tmp_dir, name)): + i += 1 + return i + + def update(self, other=None, **kwargs): + """ + Override default dictionary update method + """ + if other is not None: + if isinstance(other, dict): + for k, v in other.items(): + self._set_host_data(k, v) + else: + for k, v in other: + self._set_host_data(k, v) + + for k, v in kwargs.items(): + self._set_host_data(k, v) + + def update_host(self, hostname, key_values, overwrite=True): + """ + Update a host's information. This is called by various collectors such + as the ansible setup module output and the hosts parser to add + information to a host. It does some deep inspection to make sure nested + information can be updated. + """ + + default_empty_host_data = { + 'name': hostname, + 'hostvars': {}, + } + + host_data = self.get(hostname, default_empty_host_data) + util.deepupdate(host_data, key_values, overwrite=overwrite) + + self._set_host_data(hostname, host_data) + + def items(self): + """ + Walk through the temporary directory and yield hosts data one by one + """ + for name in os.listdir(self._tmp_dir): + filename = os.path.join(self._tmp_dir, name) + if os.path.isfile(filename): + host_data = self._load_data_from_file(filename) + if self._host_matches_limits(host_data): + hostname = host_data['name'] + yield(hostname, host_data) + + def _hostname2filename(self, hostname): + """ + Create host storage file from its hostname but convert to base64 to + filter non ascii characters + """ + filename = base64.urlsafe_b64encode(hostname.encode('utf-8')).decode() + return filename + + def _parse_limit(self, limit): + """ + Parse a host / group limit in the form of a string (e.g. 
+ 'all:!cust.acme') into a dict of things to be included and things to be + excluded. + """ + + limit_parsed = { + 'include': [], + 'exclude': [] + } + + elems = limit.split(":") + for elem in elems: + if elem.startswith('!'): + limit_parsed['exclude'].append(elem[1:]) + else: + limit_parsed['include'].append(elem) + + self.log.debug("Hosts limits applied: {0}".format(limit_parsed)) + return limit_parsed + + def _host_matches_limits(self, host): + """ + Test if the host satisfies given include/exclude limits + """ + + # return true if limit is not set + if self.limit is None: + return True + + # add hostname and a copy of the host groups to a single names list + names = [] + if 'groups' in host: + names = list(host['groups']) + names.append(host['name']) + + # return false if hostname or group is in exclude list ('all' matches every host) + for exclude in self.limit['exclude']: + if exclude == 'all' or exclude in names: + return False + + # return true if include list is empty + if not self.limit['include']: + return True + + # return true if hostname or group matches include list ('all' matches every host) + for include in self.limit['include']: + if include == 'all' or include in names: + return True + + # return false if include limit not matched + return False + + def setLimit(self, limit): + """ + Set include/exclude limit to filter the hosts returned from get(), items(), etc. + """ + if limit is None: + return + self.limit = self._parse_limit(limit) + + def hosts_all(self): + """ + Yield the hostnames of all hosts. + """ + for hostname, hostinfo in self.items(): + yield hostname + + def hosts_in_group(self, groupname): + """ + Yield the hostnames that are in a group. 
+ """ + if groupname == 'all': + for hostname, hostinfo in self.items(): + yield hostname + else: + for hostname, hostinfo in self.items(): + if 'groups' in hostinfo: + if groupname in hostinfo['groups']: + yield hostname + + def _set_host_data(self, hostname, host_data): + """ + Check if host data matched the limit and save it to file + """ + if self._host_matches_limits(host_data): + filename = os.path.join(self._tmp_dir, self._hostname2filename(hostname)) + self._save_data_to_file(host_data, filename) + + def _save_data_to_file(self, data, filename): + """ + Write host data to file + """ + with codecs.open(filename, 'wb') as handle: + self.log.debug("Writing host data to file: {0}".format(filename)) + pickle.dump(data, handle) + + def _load_data_from_file(self, filename): + """ + Load host data from file + """ + self.log.debug("Reading host data from file {0}".format(filename)) + with codecs.open(filename, 'rb') as handle: + return pickle.load(handle) From 7871796a6ca154ee4f33fc0c164b9bf1621c518a Mon Sep 17 00:00:00 2001 From: Aleksandr Cupacenko Date: Thu, 19 Mar 2020 10:25:14 +0100 Subject: [PATCH 2/2] add tests for AnsibleHosts class --- test/f_ansiblehosts/hosts | 14 +++++++++ test/f_ansiblehosts/out/db.dev.local | 9 ++++++ test/test.py | 45 ++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 test/f_ansiblehosts/hosts create mode 100644 test/f_ansiblehosts/out/db.dev.local diff --git a/test/f_ansiblehosts/hosts b/test/f_ansiblehosts/hosts new file mode 100644 index 0000000..a2958e2 --- /dev/null +++ b/test/f_ansiblehosts/hosts @@ -0,0 +1,14 @@ +[web] +web01.local +web02.local +web03.local + +[db] +db01.local +db02.local +db03.local + +[app] +app01.local +app02.local +app03.local diff --git a/test/f_ansiblehosts/out/db.dev.local b/test/f_ansiblehosts/out/db.dev.local new file mode 100644 index 0000000..d5e4e95 --- /dev/null +++ b/test/f_ansiblehosts/out/db.dev.local @@ -0,0 +1,9 @@ +{ + "ansible_facts": { + "ansible_fqdn": 
"localhost", + "ansible_hostname": "dev", + "ansible_nodename": "db.dev.local", + "module_setup": true + }, + "changed": false +} diff --git a/test/test.py b/test/test.py index 667c7a1..d1609c8 100644 --- a/test/test.py +++ b/test/test.py @@ -132,6 +132,51 @@ def testFactCache(self): self.assertIn('ansible_env', ansible_facts) +class AnsibleHostsTestCase(unittest.TestCase): + """ + Test that AnsibleHosts class required functionality + """ + fact_dirs = ['f_ansiblehosts/out'] + inventories = ['f_ansiblehosts/hosts'] + ansible = ansiblecmdb.Ansible(fact_dirs, inventories) + + def testAnsibleHostsLen(self): + self.assertEqual(len(self.ansible.hosts), 10) + + def testAnsibleHostsGetItem(self): + self.assertIn('web01.local', self.ansible.hosts) + self.assertTrue(self.ansible.hosts['web01.local']) + self.assertTrue(self.ansible.hosts.get('web01.local')) + self.assertFalse(self.ansible.hosts.get('nonexistent')) + + def testAnsibleHostsIterable(self): + try: + iterator = iter(self.ansible.hosts.items()) + except TypeError: + self.fail('AnsibleHosts object is not iterable') + + def testAnsibleHostsSetItem(self): + self.ansible.hosts['newhost'] = dict(name='newhost', hostvars={}) + self.assertIn('newhost', self.ansible.hosts) + + def testAnsibleHostsUpdate(self): + update_value = dict( { 'hostvars': { 'department': 'finance' } } ) + self.ansible.hosts.update_host('web01.local', update_value) + self.assertEqual(self.ansible.hosts['web01.local']['hostvars']['department'], 'finance') + + # update dictionary method will reset the value + update_item = dict( { 'web01.local': { 'hostvars': { 'location': 'us-east' } } } ) + self.ansible.hosts.update(update_item) + self.assertEqual(self.ansible.hosts['web01.local']['hostvars']['location'], 'us-east') + self.assertNotIn('department', self.ansible.hosts['web01.local']['hostvars']) + + # update_host method will only update individual tree leaves + update_value = dict( { 'hostvars': { 'department': 'finance' } } ) + 
self.ansible.hosts.update_host('web01.local', update_value) + self.assertEqual(self.ansible.hosts['web01.local']['hostvars']['department'], 'finance') + self.assertEqual(self.ansible.hosts['web01.local']['hostvars']['location'], 'us-east') + + if __name__ == '__main__': unittest.main(exit=True)