Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions es/separable_natural_es.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
import logging
import multiprocessing as mp
import numpy as np
from . import lib

logger = logging.getLogger(__name__)


def optimize(func, mu, sigma,
learning_rate_mu=None, learning_rate_sigma=None, population_size=None,
max_iter=2000,
fitness_shaping=True, mirrored_sampling=True, record_history=False,
rng=None,
parallel_threads=None):
parallel_threads=None,
verbosity=logging.WARNING):
"""
Evolution strategies using the natural gradient of multinormal search distributions in natural coordinates.
Does not consider covariances between parameters.
See Wierstra et al. (2014). Natural evolution strategies. Journal of Machine Learning Research, 15(1), 949-980.
"""

logger_ch = logging.StreamHandler()
logger_fh = logging.FileHandler('snes.log', 'w')
logger_ch.setFormatter(logging.Formatter('%(asctime)s %(levelname)s:%(name)s %(message)s'))
logger_fh.setFormatter(logging.Formatter('%(asctime)s %(levelname)s:%(name)s %(message)s'))
logger.setLevel(verbosity)
logger.addHandler(logger_ch)
logger.addHandler(logger_fh)

if not isinstance(mu, np.ndarray):
raise TypeError('mu needs to be of type np.ndarray')
if not isinstance(sigma, np.ndarray):
Expand All @@ -37,6 +50,11 @@ def optimize(func, mu, sigma,
history_pop = []
history_fitness = []

n_total_individuals = population_size * (1 + int(mirrored_sampling))
logger.info('starting evolution with {} individuals per generation on {} threads'.format(n_total_individuals, parallel_threads))
if (parallel_threads is None and n_total_individuals > 1) or (parallel_threads is not None and n_total_individuals > parallel_threads):
logger.warning('more individuals than parallel threads. expect long runtime')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"expect long runtime" sounds a bit casual. What about

"more individuals ({}) than parallel threads ({}). consider increasing the number of parallel threads.".format(n_total_individuals, parallel_threads)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a convention to use only lowercase in logging messages?


while True:
s = rng.normal(0, 1, size=(population_size, *np.shape(mu)))
z = mu + sigma * s
Expand All @@ -45,13 +63,13 @@ def optimize(func, mu, sigma,
z = np.vstack([z, mu - sigma * s])
s = np.vstack([s, -s])

generations_list = [generation] * len(z)
individual_list = range(len(z))
if parallel_threads is None:
fitness = np.fromiter((func(zi) for zi in z), np.float)
fitness = np.fromiter((func(zi, gi, ii) for zi, gi, ii in zip(z, generations_list, individual_list)), np.float)
else:
pool = mp.Pool(processes=parallel_threads)
fitness = np.fromiter(pool.map(func, z), np.float)
pool.close()
pool.join()
with mp.Pool(processes=parallel_threads) as pool:
fitness = np.fromiter(pool.starmap(func, zip(z, generations_list, individual_list)), np.float)

ni = np.logical_not(np.isnan(fitness))
z = z[ni]
Expand All @@ -69,6 +87,11 @@ def optimize(func, mu, sigma,
mu += learning_rate_mu * sigma * np.dot(utility, s)
sigma *= np.exp(learning_rate_sigma / 2. * np.dot(utility, s ** 2 - 1))

logger.info('generation {}, average fitness {}'.format(generation, np.mean(fitness)))
logger.debug('fitness {}'.format(fitness))
logger.debug('mu {}'.format(mu))
logger.debug('sigma {}'.format(sigma))

if record_history:
history_mu.append(mu.copy())
history_sigma.append(sigma.copy())
Expand All @@ -78,7 +101,11 @@ def optimize(func, mu, sigma,
generation += 1

# exit if max iterations reached
if generation > max_iter or np.all(sigma < 1e-10):
if generation > max_iter:
logger.info('maximum number of iterations reached - exiting')
break
elif np.all(sigma < 1e-10):
logger.info('convergence of sigma detected - exiting')
break

return {'mu': mu,
Expand All @@ -87,4 +114,3 @@ def optimize(func, mu, sigma,
'history_sigma': history_sigma,
'history_fitness': history_fitness,
'history_pop': history_pop}