From 3b3a61ceaf29bde851cd058620d024193fb9f19a Mon Sep 17 00:00:00 2001
From: Jakob Jordan
Date: Wed, 16 May 2018 16:37:10 +0200
Subject: [PATCH 1/2] Provide generation and individual id to objective in snes

---
 es/separable_natural_es.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/es/separable_natural_es.py b/es/separable_natural_es.py
index a1bd3de..7dc8b61 100644
--- a/es/separable_natural_es.py
+++ b/es/separable_natural_es.py
@@ -2,6 +2,7 @@ import numpy as np
 
 from . import lib
 
+
 def optimize(func, mu, sigma,
              learning_rate_mu=None, learning_rate_sigma=None,
              population_size=None, max_iter=2000,
@@ -45,13 +46,13 @@ def optimize(func, mu, sigma,
             z = np.vstack([z, mu - sigma * s])
             s = np.vstack([s, -s])
 
+        generations_list = [generation] * len(z)
+        individual_list = range(len(z))
         if parallel_threads is None:
-            fitness = np.fromiter((func(zi) for zi in z), np.float)
+            fitness = np.fromiter((func(zi, gi, ii) for zi, gi, ii in zip(z, generations_list, individual_list)), np.float)
         else:
-            pool = mp.Pool(processes=parallel_threads)
-            fitness = np.fromiter(pool.map(func, z), np.float)
-            pool.close()
-            pool.join()
+            with mp.Pool(processes=parallel_threads) as pool:
+                fitness = np.fromiter(pool.starmap(func, zip(z, generations_list, individual_list)), np.float)
 
         ni = np.logical_not(np.isnan(fitness))
         z = z[ni]
@@ -87,4 +88,3 @@ def optimize(func, mu, sigma,
             'history_sigma': history_sigma,
             'history_fitness': history_fitness,
             'history_pop': history_pop}
-
\ No newline at end of file

From b4046eb62303bb6577811e4356e326ec8a904888 Mon Sep 17 00:00:00 2001
From: Jakob Jordan
Date: Tue, 29 May 2018 10:27:48 +0200
Subject: [PATCH 2/2] Basic implementation of checkpointing

---
 es/lib.py                  | 10 +++++++
 es/separable_natural_es.py | 53 +++++++++++++++++++++++++++++++++-----
 2 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/es/lib.py b/es/lib.py
index 00cea09..5aa5fc6 100644
--- a/es/lib.py
+++ b/es/lib.py
@@ -1,4 +1,14 @@
 import numpy as np
+import pickle
+
+
+def create_results_dict(mu, sigma, history_mu, history_sigma, history_fitness, history_pop):
+    return {'mu': mu,
+            'sigma': sigma,
+            'history_mu': history_mu,
+            'history_sigma': history_sigma,
+            'history_fitness': history_fitness,
+            'history_pop': history_pop}
 
 
 def default_population_size(dimensions):
diff --git a/es/separable_natural_es.py b/es/separable_natural_es.py
index 7dc8b61..38f9aee 100644
--- a/es/separable_natural_es.py
+++ b/es/separable_natural_es.py
@@ -1,14 +1,39 @@
+import glob
 import multiprocessing as mp
 import numpy as np
+import pickle
 
 from . import lib
 
 
+def load_checkpoint():
+    filenames = glob.glob('checkpoint-*.pkl')
+    if filenames:
+        most_recent_checkpoint = max(filenames, key=lambda x: int(x.split('-')[1].split('.')[0]))
+        with open(most_recent_checkpoint, 'rb') as f:
+            most_recent_state = pickle.load(f)
+        return most_recent_state
+    else:
+        return None
+
+
+def create_checkpoint(locals_dict, whitelist, label):
+    state = extract_keys_from_dict(locals_dict, whitelist)
+    with open('checkpoint-{}.pkl'.format(label), 'wb') as f:
+        pickle.dump(state, f)
+
+
+def extract_keys_from_dict(d, whitelist, blacklist=[]):
+    return {key: value for key, value in d.items() if key in whitelist and key not in blacklist}
+
+
 def optimize(func, mu, sigma,
              learning_rate_mu=None, learning_rate_sigma=None,
              population_size=None, max_iter=2000,
              fitness_shaping=True, mirrored_sampling=True, record_history=False,
              rng=None,
-             parallel_threads=None):
+             parallel_threads=None,
+             checkpoint_interval=None,
+             load_existing_checkpoint=False):
     """
     Evolution strategies using the natural gradient of multinormal search distributions in natural coordinates. Does not consider covariances between parameters.
@@ -32,12 +57,28 @@ def optimize(func, mu, sigma,
     elif isinstance(rng, int):
         rng = np.random.RandomState(seed=rng)
 
+    mu = mu.copy()
+    sigma = sigma.copy()
     generation = 0
     history_mu = []
     history_sigma = []
     history_pop = []
     history_fitness = []
 
+    mutable_locals = ['rng', 'mu', 'sigma', 'generation', 'history_mu', 'history_sigma', 'history_pop', 'history_fitness']
+
+    if load_existing_checkpoint:
+        state = load_checkpoint()
+        if state:
+            rng = state['rng']
+            mu = state['mu']
+            sigma = state['sigma']
+            generation = state['generation'] + 1
+            history_mu = state['history_mu']
+            history_sigma = state['history_sigma']
+            history_pop = state['history_pop']
+            history_fitness = state['history_fitness']
+
     while True:
         s = rng.normal(0, 1, size=(population_size, *np.shape(mu)))
         z = mu + sigma * s
@@ -76,15 +117,13 @@ def optimize(func, mu, sigma,
             history_pop.append(z.copy())
             history_fitness.append(fitness.copy())
 
+        if checkpoint_interval is not None and generation % checkpoint_interval == 0:
+            create_checkpoint(locals(), mutable_locals, generation)
+
         generation += 1
 
         # exit if max iterations reached
        if generation > max_iter or np.all(sigma < 1e-10):
             break
 
-    return {'mu': mu,
-            'sigma': sigma,
-            'history_mu': history_mu,
-            'history_sigma': history_sigma,
-            'history_fitness': history_fitness,
-            'history_pop': history_pop}
+    return lib.create_results_dict(mu, sigma, history_mu, history_sigma, history_fitness, history_pop)
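
Usage note: the sketch below is not part of the patch series; it only illustrates how optimize() could be called once both patches are applied, assuming the package is importable as es. With patch 1 the objective receives the generation index and the individual index in addition to the parameter vector; with patch 2, checkpoint_interval controls how often a checkpoint-<generation>.pkl file is written and load_existing_checkpoint=True resumes from the most recent such file. The objective function, initial mu and sigma, and the interval value are made-up examples.

import numpy as np

from es import separable_natural_es as snes


def objective(x, generation, individual):
    # hypothetical toy objective returning a scalar fitness; generation and
    # individual are supplied by the patched optimize() and could be used
    # e.g. for logging or per-individual seeding
    return -np.sum(x ** 2)


result = snes.optimize(
    objective,
    mu=np.array([0.5, -0.3]),       # illustrative initial mean
    sigma=np.array([1.0, 1.0]),     # illustrative initial step sizes
    checkpoint_interval=10,         # write checkpoint-<generation>.pkl every 10 generations
    load_existing_checkpoint=True,  # resume from the newest checkpoint-*.pkl, if any exists
)
print(result['mu'], result['sigma'])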