Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions es/lib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import numpy as np
import pickle


def create_results_dict(mu, sigma, history_mu, history_sigma, history_fitness, history_pop):
    """Bundle the final search distribution and per-generation histories into a results dict.

    Keys mirror the argument names: 'mu', 'sigma', 'history_mu',
    'history_sigma', 'history_fitness', 'history_pop'.
    """
    keys = ('mu', 'sigma', 'history_mu', 'history_sigma', 'history_fitness', 'history_pop')
    values = (mu, sigma, history_mu, history_sigma, history_fitness, history_pop)
    return dict(zip(keys, values))


def default_population_size(dimensions):
Expand Down
65 changes: 52 additions & 13 deletions es/separable_natural_es.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
import glob
import multiprocessing as mp
import numpy as np
import pickle
from . import lib


def load_checkpoint(generation=None):
    """Load a previously pickled checkpoint state from the working directory.

    Generalized per review: callers may now request a specific checkpoint
    instead of only the most recent one. Calling with no argument keeps the
    original behavior.

    Parameters
    ----------
    generation : int, optional
        Generation label of the checkpoint to load ('checkpoint-<generation>.pkl').
        If None (default), the checkpoint with the highest generation number
        is loaded.

    Returns
    -------
    dict or None
        The unpickled state dict, or None if no matching checkpoint file exists.
    """
    if generation is None:
        pattern = 'checkpoint-*.pkl'
    else:
        pattern = 'checkpoint-{}.pkl'.format(generation)
    filenames = glob.glob(pattern)
    if not filenames:
        # No checkpoint found; caller is expected to handle None.
        return None
    # Filenames look like 'checkpoint-<generation>.pkl'; sort numerically,
    # not lexically, so e.g. generation 10 beats generation 9.
    most_recent_checkpoint = max(filenames, key=lambda fn: int(fn.split('-')[1].split('.')[0]))
    with open(most_recent_checkpoint, 'rb') as f:
        # NOTE: pickle.load must only be used on trusted, locally-written checkpoints.
        return pickle.load(f)


def create_checkpoint(locals_dict, whitelist, label):
    """Pickle the whitelisted entries of locals_dict to 'checkpoint-<label>.pkl'.

    Parameters
    ----------
    locals_dict : dict
        Mapping of variable names to values (typically a locals() snapshot).
    whitelist : iterable
        Names from locals_dict to include in the checkpoint.
    label : object
        Identifier embedded in the checkpoint filename (e.g. the generation).
    """
    filename = 'checkpoint-{}.pkl'.format(label)
    state = extract_keys_from_dict(locals_dict, whitelist)
    with open(filename, 'wb') as checkpoint_file:
        pickle.dump(state, checkpoint_file)


def extract_keys_from_dict(d, whitelist, blacklist=None):
    """Return a new dict with the entries of d whose keys pass the filters.

    A key is kept if it appears in whitelist and does not appear in blacklist.

    Parameters
    ----------
    d : dict
        Source mapping.
    whitelist : iterable
        Keys to include.
    blacklist : iterable, optional
        Keys to exclude even if whitelisted. Defaults to no exclusions.
        (Fixed: the previous mutable default argument `[]` is replaced with
        a None sentinel — same behavior, no shared-default pitfall.)

    Returns
    -------
    dict
        New dict; d is not modified.
    """
    if blacklist is None:
        blacklist = ()
    return {key: value for key, value in d.items() if key in whitelist and key not in blacklist}


def optimize(func, mu, sigma,
learning_rate_mu=None, learning_rate_sigma=None, population_size=None,
max_iter=2000,
fitness_shaping=True, mirrored_sampling=True, record_history=False,
rng=None,
parallel_threads=None):
parallel_threads=None,
checkpoint_interval=None,
load_existing_checkpoint=False):
"""
Evolution strategies using the natural gradient of multinormal search distributions in natural coordinates.
Does not consider covariances between parameters.
Expand All @@ -31,12 +57,28 @@ def optimize(func, mu, sigma,
elif isinstance(rng, int):
rng = np.random.RandomState(seed=rng)

mu = mu.copy()
sigma = sigma.copy()
generation = 0
history_mu = []
history_sigma = []
history_pop = []
history_fitness = []

mutable_locals = ['rng', 'mu', 'sigma', 'generation', 'history_mu', 'history_sigma', 'history_pop', 'history_fitness']

if load_existing_checkpoint:
state = load_checkpoint()
if state:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If state is None here, i.e. if there were no checkpoints that can be loaded, I suggest to raise an error. Otherwise checkpoint loading fails silently.

rng = state['rng']
mu = state['mu']
sigma = state['sigma']
generation = state['generation'] + 1
history_mu = state['history_mu']
history_sigma = state['history_sigma']
history_pop = state['history_pop']
history_fitness = state['history_fitness']

while True:
s = rng.normal(0, 1, size=(population_size, *np.shape(mu)))
z = mu + sigma * s
Expand All @@ -45,13 +87,13 @@ def optimize(func, mu, sigma,
z = np.vstack([z, mu - sigma * s])
s = np.vstack([s, -s])

generations_list = [generation] * len(z)
individual_list = range(len(z))
if parallel_threads is None:
fitness = np.fromiter((func(zi) for zi in z), np.float)
fitness = np.fromiter((func(zi, gi, ii) for zi, gi, ii in zip(z, generations_list, individual_list)), np.float)
else:
pool = mp.Pool(processes=parallel_threads)
fitness = np.fromiter(pool.map(func, z), np.float)
pool.close()
pool.join()
with mp.Pool(processes=parallel_threads) as pool:
fitness = np.fromiter(pool.starmap(func, zip(z, generations_list, individual_list)), np.float)

ni = np.logical_not(np.isnan(fitness))
z = z[ni]
Expand All @@ -75,16 +117,13 @@ def optimize(func, mu, sigma,
history_pop.append(z.copy())
history_fitness.append(fitness.copy())

if checkpoint_interval is not None and generation % checkpoint_interval == 0:
create_checkpoint(locals(), mutable_locals, generation)

generation += 1

# exit if max iterations reached
if generation > max_iter or np.all(sigma < 1e-10):
break

return {'mu': mu,
'sigma': sigma,
'history_mu': history_mu,
'history_sigma': history_sigma,
'history_fitness': history_fitness,
'history_pop': history_pop}

return lib.create_results_dict(mu, sigma, history_mu, history_sigma, history_fitness, history_pop)