From 3b3a61ceaf29bde851cd058620d024193fb9f19a Mon Sep 17 00:00:00 2001 From: Jakob Jordan Date: Wed, 16 May 2018 16:37:10 +0200 Subject: [PATCH] Provide generation and individual id to objective in snes --- es/separable_natural_es.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/es/separable_natural_es.py b/es/separable_natural_es.py index a1bd3de..7dc8b61 100644 --- a/es/separable_natural_es.py +++ b/es/separable_natural_es.py @@ -2,6 +2,7 @@ import numpy as np from . import lib + def optimize(func, mu, sigma, learning_rate_mu=None, learning_rate_sigma=None, population_size=None, max_iter=2000, @@ -45,13 +46,13 @@ def optimize(func, mu, sigma, z = np.vstack([z, mu - sigma * s]) s = np.vstack([s, -s]) + generations_list = [generation] * len(z) + individual_list = range(len(z)) if parallel_threads is None: - fitness = np.fromiter((func(zi) for zi in z), np.float) + fitness = np.fromiter((func(zi, gi, ii) for zi, gi, ii in zip(z, generations_list, individual_list)), np.float) else: - pool = mp.Pool(processes=parallel_threads) - fitness = np.fromiter(pool.map(func, z), np.float) - pool.close() - pool.join() + with mp.Pool(processes=parallel_threads) as pool: + fitness = np.fromiter(pool.starmap(func, zip(z, generations_list, individual_list)), np.float) ni = np.logical_not(np.isnan(fitness)) z = z[ni] @@ -87,4 +88,3 @@ def optimize(func, mu, sigma, 'history_sigma': history_sigma, 'history_fitness': history_fitness, 'history_pop': history_pop} - \ No newline at end of file