From 5204053813d70c6e7a1350e7c011058dd262cdca Mon Sep 17 00:00:00 2001 From: Maximilian Schmidt Date: Thu, 17 May 2018 10:20:52 +0900 Subject: [PATCH] Implement option to pass a wrapper around the fitness function to the optimize function to enable flexible initialization of the fitness at each generation --- es/separable_natural_es.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/es/separable_natural_es.py b/es/separable_natural_es.py index a1bd3de..dc6b0dd 100644 --- a/es/separable_natural_es.py +++ b/es/separable_natural_es.py @@ -2,7 +2,9 @@ import numpy as np from . import lib -def optimize(func, mu, sigma, +def optimize(func, + mu, sigma, + func_wrapper=None, learning_rate_mu=None, learning_rate_sigma=None, population_size=None, max_iter=2000, fitness_shaping=True, mirrored_sampling=True, record_history=False, @@ -45,11 +47,16 @@ def optimize(func, mu, sigma, z = np.vstack([z, mu - sigma * s]) s = np.vstack([s, -s]) - if parallel_threads is None: - fitness = np.fromiter((func(zi) for zi in z), np.float) + if func_wrapper is not None: + func_i = func_wrapper(func) else: + func_i = func + + if parallel_threads is None: + fitness = np.fromiter((func_i(zi) for zi in z), np.float) + elif isinstance(parallel_threads, int): pool = mp.Pool(processes=parallel_threads) - fitness = np.fromiter(pool.map(func, z), np.float) + fitness = np.fromiter(pool.map(func_i, z), np.float) pool.close() pool.join()