diff --git a/README.md b/README.md index 4300c57..c22c498 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ -# pyhms +# pyHMS + +pyHMS Logo + ![GitHub Test Badge][1] [![codecov][2]](https://codecov.io/gh/agh-a2s/pyhms) [![Documentation Status][3]](https://pyhms.readthedocs.io/en/latest/?badge=latest) [![pypi.org][4]][5] [![versions][6]][7] ![license][8] [1]: https://github.com/agh-a2s/pyhms/actions/workflows/pytest.yml/badge.svg "GitHub CI Badge" @@ -10,7 +13,7 @@ [7]: https://github.com/agh-a2s/pyhms [8]: https://img.shields.io/github/license/agh-a2s/pyhms -`pyhms` is a Python implementation of Hierarchic Memetic Strategy (HMS). +`pyHMS` is a Python implementation of Hierarchic Memetic Strategy (HMS). The Hierarchic Memetic Strategy is a stochastic global optimizer designed to tackle highly multimodal problems. It is a composite global optimization strategy consisting of a multi-population evolutionary strategy and some auxiliary methods. The HMS makes use of a dynamically-evolving data structure that provides an organization among the component populations. It is a tree with a fixed maximal height and variable internal node degree. Each component population is governed by a particular optimization engine. This package provides a simple python implementation. diff --git a/docs/_static/images/pyhms.png b/docs/_static/images/pyhms.png new file mode 100644 index 0000000..69e3bcf Binary files /dev/null and b/docs/_static/images/pyhms.png differ diff --git a/docs/custom_demes.rst b/docs/custom_demes.rst new file mode 100644 index 0000000..9015343 --- /dev/null +++ b/docs/custom_demes.rst @@ -0,0 +1,165 @@ +Adding Custom Demes to pyHMS +============================ + +This guide explains how to create your own custom deme implementations for pyHMS. + +Overview +-------- + +pyHMS allows you to extend the system with your own custom deme implementations. To create a custom deme, you need to: + +1. Define a new config class that inherits from ``BaseLevelConfig`` +2. Create a new deme class that inherits from ``AbstractDeme`` +3. Register your custom deme by passing a ``config_class_to_deme_class`` mapping to the ``hms`` function + +Step 1: Define Your Config Class +-------------------------------- + +Start by creating a config class that inherits from ``BaseLevelConfig``. This class should: + +- Accept a ``problem`` and a stop condition (``lsc``) as required parameters +- Include any additional parameters your deme implementation needs +- Call the parent class's ``__init__`` method + +.. code-block:: python + + from pyhms.config import BaseLevelConfig + from pyhms.core.problem import Problem + from pyhms.stop_conditions import LocalStopCondition, UniversalStopCondition + + class RandomSearchConfig(BaseLevelConfig): + def __init__( + self, + problem: Problem, + lsc: LocalStopCondition | UniversalStopCondition, + pop_size: int, + ) -> None: + super().__init__(problem, lsc) + self.pop_size = pop_size + +Step 2: Create Your Deme Class +------------------------------ + +Next, create a deme class that inherits from ``AbstractDeme``. This class must implement the required interface: + +.. code-block:: python + + import numpy as np + from pyhms.core.individual import Individual + from pyhms.demes.abstract_deme import AbstractDeme, DemeInitArgs + + class RandomSearchDeme(AbstractDeme): + def __init__( + self, + deme_init_args: DemeInitArgs, + ) -> None: + super().__init__(deme_init_args) + config: RandomSearchConfig = deme_init_args.config # type: ignore[assignment] + self._pop_size = config.pop_size + self.lower_bounds = config.bounds[:, 0] + self.upper_bounds = config.bounds[:, 1] + self.rng = np.random.RandomState(deme_init_args.random_seed) + self.run() + + def run(self) -> None: + genomes = np.random.uniform( + self.lower_bounds, + self.upper_bounds, + size=(self._pop_size, len(self.lower_bounds)) + ) + population = [Individual(genome, problem=self._problem) for genome in genomes] + Individual.evaluate_population(population) + self._history.append([population]) + + def run_metaepoch(self, tree) -> None: + # This method is called in each meta-epoch + self.run() + + # Check if stopping conditions are met + if (gsc_value := tree._gsc(tree)) or self._lsc(self): + self._active = False + message = "Random Search Deme finished due to GSC" if gsc_value else "Random Search Deme finished due to LSC" + self.log(message) + return + +Understanding DemeInitArgs +-------------------------- + +When implementing a custom deme, you'll receive a ``DemeInitArgs`` object in the constructor. This dataclass contains all the necessary initialization parameters for your deme: + +.. code-block:: python + + @dataclass + class DemeInitArgs: + id: str + level: int + config: BaseLevelConfig + logger: FilteringBoundLogger + started_at: int = 0 + sprout_seed: Individual | None = None + random_seed: int | None = None + parent_deme: AbstractDeme | None = None + +Understanding these fields: + +- ``id``: A unique string identifier for your deme +- ``level``: The hierarchical level in the HMS tree (starts at 0 for root) +- ``config``: Your custom configuration class instance that inherits from ``BaseLevelConfig`` +- ``logger``: A structured logger for outputting debug information +- ``started_at``: The metaepoch number when this deme was created +- ``sprout_seed``: For non-root demes, this is the first individual that sprouted this deme +- ``random_seed``: A seed for random number generation to ensure reproducibility +- ``parent_deme``: Reference to the parent deme that sprouted this deme (None for root demes) + +In your custom deme implementation, you'll typically: + +1. Pass the ``DemeInitArgs`` object to the parent constructor +2. Cast the ``config`` field to your specific config class type +3. Access the configuration parameters you need +4. Use the provided random seed for any randomized operations + +Step 3: Register and Use Your Custom Deme +----------------------------------------- + +Finally, register your custom deme by creating a mapping from your config class to your deme class and passing it to the ``hms`` function: + +.. code-block:: python + + from pyhms import hms + from pyhms.stop_conditions import DontStop, MetaepochLimit + + # Create your deme configuration + random_search_config = RandomSearchConfig( + problem=your_problem, + lsc=DontStop(), + pop_size=100 + ) + + # Define the mapping from config class to deme class + config_class_to_deme_class = { + RandomSearchConfig: RandomSearchDeme + } + + # Use your custom deme in pyHMS + result = hms( + level_config=[random_search_config], + gsc=MetaepochLimit(10), + sprout_cond=your_sprout_condition, + config_class_to_deme_class=config_class_to_deme_class + ) + +Important AbstractDeme Properties and Methods +--------------------------------------------- + +When implementing your custom deme, you can use the following properties and methods from the ``AbstractDeme`` base class: + +- ``self._problem``: The optimization problem +- ``self._bounds``: The bounds of the search space +- ``self._active``: A flag indicating if the deme is active +- ``self._history``: History of populations (list of lists of individuals) +- ``self.log(message)``: Log a message with additional meta information +- ``self.centroid``: Compute the centroid of the current population +- ``self.best_individual``: Get the best individual found by the deme +- ``self.current_population``: Get the current population + +The most important method you must implement is ``run_metaepoch(self, tree)``, which is called in each meta-epoch of the HMS algorithm. diff --git a/docs/index.rst b/docs/index.rst index ffd51a3..2c98881 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,9 +1,14 @@ .. include:: ../README.rst -Welcome to pyhms's documentation! +Welcome to pyHMS's documentation! =================================== -**pyhms** is a Python implementation of Hierarchic Memetic Strategy (HMS). +.. image:: _static/images/pyhms.png + :width: 200px + :alt: pyHMS Logo + :align: center + +**pyHMS** is a Python implementation of Hierarchic Memetic Strategy (HMS). The Hierarchic Memetic Strategy is a stochastic global optimizer designed to tackle highly multimodal problems. It is a composite global optimization strategy consisting of a multi-population evolutionary strategy and some auxiliary methods. The HMS makes use of a dynamically-evolving data structure that provides an organization among the component populations. It is a tree with a fixed maximal height and variable internal node degree. Each component population is governed by a particular optimization engine. This package provides a simple python implementation. @@ -23,6 +28,7 @@ Contents algorithm usage inspecting - stop sprout + custom_demes + stop problem diff --git a/docs/inspecting.rst b/docs/inspecting.rst index 088b8e4..2fee400 100644 --- a/docs/inspecting.rst +++ b/docs/inspecting.rst @@ -6,7 +6,7 @@ Visualizing and inspecting the results of evolutionary strategies (ES) are cruci Visualization helps in understanding how solutions evolve over generations. Through visual inspection, one can observe if the population is converging towards a global optimum or if it is stuck in local optima. Visualization of the population distribution over time can also highlight issues with diversity, indicating whether the evolutionary strategy is exploring the solution space adequately. -`pyhms` provides different methods for `DemeTree` object that enable inspecting results. +`pyHMS` provides different methods for `DemeTree` object that enable inspecting results. Let's consider Sphere function as an example for `N=5`. .. code-block:: python diff --git a/docs/problem.rst b/docs/problem.rst index 8efa5e9..56fca38 100644 --- a/docs/problem.rst +++ b/docs/problem.rst @@ -1,12 +1,53 @@ Problem ======= -`pyhms` provides different `Problem` wrappers. These wrappers are used to wrap the problem and provide additional functionality such as counting the number of evaluations (`EvalCountingProblem`, `EvalCutoffProblem`), or stopping the evaluation when a certain precision is reached (`PrecisionCutoffProblem`). +The `Problem` class hierarchy in `pyHMS` provides the foundation for defining optimization problems and wrapping them with additional functionality. + +Base Classes +------------ + +.. autoclass:: pyhms.core.problem.Problem + :members: + +.. autoclass:: pyhms.core.problem.FunctionProblem + :members: + +.. autoclass:: pyhms.core.problem.ProblemWrapper + :members: + +Problem Wrappers +---------------- + +`pyHMS` provides various `Problem` wrappers. These wrappers enhance problem instances with additional functionality without modifying their core behavior. Common use cases include: + +1. Monitoring optimization performance (counting evaluations, measuring time) +2. Enforcing constraints (maximum evaluations, precision thresholds) +3. Collecting statistics for analysis + +Available Wrappers +^^^^^^^^^^^^^^^^^^ .. autoclass:: pyhms.core.problem.EvalCountingProblem + :members: + + A wrapper that counts the number of function evaluations performed. .. autoclass:: pyhms.core.problem.EvalCutoffProblem + :members: + + A wrapper that stops evaluations after a specified limit is reached, returning infinity (or negative infinity for maximization problems). .. autoclass:: pyhms.core.problem.PrecisionCutoffProblem + :members: + + A wrapper that tracks when the solution reaches a specified precision threshold relative to the known global optimum. .. autoclass:: pyhms.core.problem.StatsGatheringProblem + :members: + + A wrapper that collects statistics about evaluation times, useful for performance analysis. + +Helper Functions +---------------- + +.. autofunction:: pyhms.core.problem.get_function_problem diff --git a/docs/usage.rst b/docs/usage.rst index 945f44c..2636f46 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -40,14 +40,14 @@ The output of the function is a OptimizeResult object: @dataclass class OptimizeResult: - x: np.ndarray - nfev: int - fun: float - nit: int + x: np.ndarray # Best solution found + nfev: int # Number of function evaluations + fun: float # Function value at the best solution + nit: int # Number of iterations (metaepochs) -Usage ------ +Detailed Usage +-------------- Let's begin by defining a problem that we want to solve. We will use the following example: @@ -74,7 +74,13 @@ To use HMS we need to define global stop condition, in this case we want to run from pyhms import MetaepochLimit global_stop_condition = MetaepochLimit(limit=10) -Now we need to decide what should be the height of our tree (maximum number of levels) and what optimization algorithms to run on each level. We will use the following configuration: +Now we need to configure the structure of our HMS tree by defining the optimization algorithms for each level. Each level configuration specifies the following: + +1. The optimization algorithm to use (`EALevelConfig` which can run multiple different GAs) +2. The number of iterations per metaepoch (`generations`) +3. The problem to solve (can be different for each level e.g. less accurate for higher levels) +4. Population size and other algorithm-specific parameters +5. Local stop condition (`lsc`) .. code-block:: python @@ -82,32 +88,36 @@ Now we need to decide what should be the height of our tree (maximum number of l config = [ EALevelConfig( - ea_class=SEA, - generations=2, - problem=square_problem, - pop_size=20, - mutation_std=1.0, - lsc=DontStop(), + ea_class=SEA, # Use Simple Evolutionary Algorithm (GA) + generations=2, # Number of generations per metaepoch + problem=square_problem, # The problem to solve (problems can be different for each level) + pop_size=20, # Population size + mutation_std=1.0, # Standard deviation for mutation + lsc=DontStop(), # Local stop condition (never stop) ), EALevelConfig( ea_class=SEA, - generations=4, + generations=4, # More generations for deeper exploration problem=square_problem, - pop_size=10, - mutation_std=0.25, - sample_std_dev=1.0, + pop_size=10, # Smaller population size at lower levels + mutation_std=0.25, # Smaller mutations for local refinement + sample_std_dev=1.0, # Standard deviation for sampling around parent lsc=DontStop(), ), ] -Next step is to define sprout condition for our tree. We will use Nearest Better Clustering (NBC) sprout condition. +The HMS algorithm creates a tree-like structure where demes (populations) at higher levels perform broad exploration, while demes at lower levels refine promising solutions. The configuration above defines two levels in our tree. + +Next, we need to define a sprouting condition that determines when and where to create new demes at lower levels. We'll use Nearest Better Clustering (NBC) sprouting: .. code-block:: python from pyhms import get_NBC_sprout sprout_condition = get_NBC_sprout(level_limit=4) -Finally we can run the algorithm: +The NBC sprouting condition identifies promising points in the search space by clustering solutions based on their fitness and proximity. See :doc:`sprout` for more details on sprouting mechanisms. + +Finally, we can run the algorithm: .. code-block:: python diff --git a/examples/landscape_approximation.ipynb b/examples/landscape_approximation.ipynb index 91a14cf..3c37cf3 100644 --- a/examples/landscape_approximation.ipynb +++ b/examples/landscape_approximation.ipynb @@ -58,11 +58,8 @@ " \n", " Parameters:\n", " s (float): Function value to transform.\n", - " \n", - " Returns:\n", - " float: Transformed value max(2s-1, 0).\n", " \"\"\"\n", - " return max(2 * s - 1, 0)\n", + " return max(10 * s - 8, 0)\n", "\n", "\n", "def c_shaped_plateau(x):\n", @@ -123,7 +120,7 @@ " TreeConfig,\n", " DemeTree,\n", ")\n", - "from pyhms.sprout import get_NBC_sprout\n", + "from pyhms.sprout import get_simple_sprout\n", "from pyhms.demes.single_pop_eas.sea import SEA\n", "\n", "N = 2\n", @@ -150,9 +147,9 @@ " ),\n", "]\n", "\n", - "global_stop_condition = SingularProblemEvalLimitReached(10000)\n", + "global_stop_condition = SingularProblemEvalLimitReached(1000)\n", "\n", - "sprout_condition = get_NBC_sprout()\n", + "sprout_condition = get_simple_sprout(far_enough=2.0)\n", "config = TreeConfig(\n", " tree_config, global_stop_condition, sprout_condition, options={\"random_seed\": 1}\n", ")\n", @@ -160,6 +157,15 @@ "hms_tree.run()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hms_tree.tree_diagram()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -182,14 +188,14 @@ "\n", "problem = FunctionProblem(c_shaped_plateau, maximize=False, bounds=bounds)\n", "\n", - "mwea = MWEA.create(problem=problem, mutation_std=0.2, p_mutation=0.35)\n", + "mwea = MWEA.create(problem=problem, mutation_std=0.5, p_mutation=0.5)\n", "\n", "la = LandscapeApproximator(\n", " hms_tree=hms_tree,\n", " merge_condition=HillValleyMergeCondition(\n", " FunctionProblem(c_shaped_plateau, maximize=False, bounds=bounds), k=10\n", " ),\n", - " local_basin_epochs=10,\n", + " local_basin_epochs=15,\n", " mwea=mwea,\n", ")\n", "la.fit()\n", @@ -202,7 +208,12 @@ "metadata": {}, "outputs": [], "source": [ - "la.plot_plateau_contour(filepath=\"kriging_contour.eps\", threshold=0.3)" + "la.plot_plateau_contour(\n", + " filepath=\"kriging_contour_with_gt.eps\",\n", + " show_true_plateau=True,\n", + " number_of_points_per_dim=250,\n", + " threshold=0.0,\n", + ")" ] }, { diff --git a/pyhms/cluster/kriging.py b/pyhms/cluster/kriging.py index c26ce64..9356afd 100644 --- a/pyhms/cluster/kriging.py +++ b/pyhms/cluster/kriging.py @@ -77,6 +77,7 @@ def plot_plateau_contour( threshold: float | None = None, number_of_points_per_dim: int = 100, filepath: str | None = None, + show_true_plateau: bool = False, ) -> None: if self.population is None or self.model is None: raise ValueError("Model must be fitted before plotting") @@ -87,7 +88,7 @@ def plot_plateau_contour( z, _ = self.model.execute("grid", x, y) if threshold is None: - threshold = np.median(self.population.fitnesses) + threshold = np.quantile(z, 0.1) fig = go.Figure( data=go.Contour( @@ -104,9 +105,34 @@ def plot_plateau_contour( line=dict(width=3, color="blue"), colorscale=[[0, "rgba(0,0,0,0)"], [1, "rgba(0,0,0,0)"]], showscale=False, + name="Kriging", ) ) + if show_true_plateau: + z_true = np.zeros((len(x), len(y))) + for i, xi in enumerate(x): + for j, yj in enumerate(y): + z_true[j, i] = self.population.problem.evaluate(np.array([xi, yj])) + fig.add_trace( + go.Contour( + x=x, + y=y, + z=z_true, + contours=dict( + start=threshold, + end=threshold, + size=0.1, + showlabels=True, + labelfont=dict(size=14, color="black"), + ), + line=dict(width=3, color="orange"), + colorscale=[[0, "rgba(0,0,0,0)"], [1, "rgba(0,0,0,0)"]], + showscale=False, + name="Ground truth", + ) + ) + fig.update_layout( xaxis_title="x", yaxis_title="y", @@ -115,6 +141,12 @@ def plot_plateau_contour( template="plotly_white", font=dict(size=16), showlegend=True, + legend=dict( + yanchor="top", + y=0.99, + xanchor="right", + x=0.99, + ), xaxis_range=[bounds[0][0], bounds[0][1]], yaxis_range=[bounds[1][0], bounds[1][1]], margin=dict(l=80, r=80, t=80, b=80), @@ -122,6 +154,9 @@ def plot_plateau_contour( yaxis=dict(automargin=True, ticklabelposition="outside"), ) + for trace in fig.data: + trace.showlegend = True + fig.show() if filepath is not None: diff --git a/pyhms/cluster/landscape_approximator.py b/pyhms/cluster/landscape_approximator.py index aab4873..3bb31e8 100644 --- a/pyhms/cluster/landscape_approximator.py +++ b/pyhms/cluster/landscape_approximator.py @@ -45,5 +45,16 @@ def predict(self, x: np.ndarray) -> np.ndarray: def plot(self, filepath: str | None = None) -> None: self.kriging.plot(filepath=filepath) - def plot_plateau_contour(self, threshold: float | None = None, filepath: str | None = None) -> None: - self.kriging.plot_plateau_contour(threshold=threshold, filepath=filepath) + def plot_plateau_contour( + self, + threshold: float | None = None, + filepath: str | None = None, + number_of_points_per_dim: int = 100, + show_true_plateau: bool = False, + ) -> None: + self.kriging.plot_plateau_contour( + threshold=threshold, + filepath=filepath, + show_true_plateau=show_true_plateau, + number_of_points_per_dim=number_of_points_per_dim, + ) diff --git a/pyhms/config.py b/pyhms/config.py index 12c62eb..2140da3 100644 --- a/pyhms/config.py +++ b/pyhms/config.py @@ -152,8 +152,10 @@ def __init__( gsc: GlobalStopCondition | UniversalStopCondition, sprout_mechanism, options: Options = DEFAULT_OPTIONS, + config_class_to_deme_class: dict[type[BaseLevelConfig], "type[AbstractDeme]"] = {}, # type: ignore # noqa: F821 ) -> None: self.levels = levels self.gsc = gsc self.sprout_mechanism = sprout_mechanism self.options = options + self.config_class_to_deme_class = config_class_to_deme_class diff --git a/pyhms/demes/abstract_deme.py b/pyhms/demes/abstract_deme.py index 1d6cb37..80e89d0 100644 --- a/pyhms/demes/abstract_deme.py +++ b/pyhms/demes/abstract_deme.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass import numpy as np from pyhms.config import BaseLevelConfig @@ -15,25 +16,33 @@ def compute_centroid(population: list[Individual]) -> np.ndarray | None: return np.mean([ind.genome for ind in population], axis=0) +@dataclass +class DemeInitArgs: + id: str + level: int + config: BaseLevelConfig + logger: FilteringBoundLogger + started_at: int = 0 + sprout_seed: Individual | None = None + random_seed: int | None = None + # Use forward reference string to avoid circular dependency + parent_deme: "AbstractDeme | None" = None + + class AbstractDeme(ABC): def __init__( self, - id: str, - level: int, - config: BaseLevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, + deme_init_args: DemeInitArgs, ) -> None: super().__init__() - self._id = id - self._started_at = started_at - self._sprout_seed = sprout_seed - self._level = level - self._config: BaseLevelConfig = config - self._lsc: LocalStopCondition | UniversalStopCondition = config.lsc - self._problem: EvalCountingProblem = EvalCountingProblem(config.problem) - self._bounds: np.ndarray = config.bounds + self._id = deme_init_args.id + self._started_at = deme_init_args.started_at + self._sprout_seed = deme_init_args.sprout_seed + self._level = deme_init_args.level + self._config: BaseLevelConfig = deme_init_args.config + self._lsc: LocalStopCondition | UniversalStopCondition = deme_init_args.config.lsc + self._problem: EvalCountingProblem = EvalCountingProblem(deme_init_args.config.problem) + self._bounds: np.ndarray = deme_init_args.config.bounds self._active: bool = True self._centroid: np.ndarray | None = None # History of populations is a nested list, where each element is a list of individuals. @@ -41,7 +50,7 @@ def __init__( # and for some algorithms (e.g. CMA-ES) HMS can run multiple generations during one metaepoch. self._history: list[list[list[Individual]]] = [] self._children: list[AbstractDeme] = [] - self._logger: FilteringBoundLogger = logger + self._logger: FilteringBoundLogger = deme_init_args.logger # Additional low-level options self._hibernating: bool = False diff --git a/pyhms/demes/cma_deme.py b/pyhms/demes/cma_deme.py index 67ba748..bb2125d 100644 --- a/pyhms/demes/cma_deme.py +++ b/pyhms/demes/cma_deme.py @@ -1,43 +1,37 @@ import numpy as np from cma import CMAEvolutionStrategy from pyhms.core.individual import Individual -from structlog.typing import FilteringBoundLogger from ..config import CMALevelConfig from ..utils.covariance_estimate import get_initial_sigma0, get_initial_stds -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs class CMADeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: CMALevelConfig, - logger: FilteringBoundLogger, - x0: Individual, - started_at: int = 0, - random_seed: int = None, - parent_deme: AbstractDeme | None = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, x0) + super().__init__(deme_init_args) + config: CMALevelConfig = deme_init_args.config # type: ignore[assignment] self.generations = config.generations lb = [bound[0] for bound in config.bounds] ub = [bound[1] for bound in config.bounds] opts = {"bounds": [lb, ub], "verbose": -9} - if random_seed is not None: + if deme_init_args.random_seed is not None: opts["randn"] = np.random.randn - opts["seed"] = random_seed + self._started_at + opts["seed"] = deme_init_args.random_seed + self._started_at + x0 = deme_init_args.sprout_seed.genome if config.__dict__.get("set_stds"): - opts["CMA_stds"] = get_initial_stds(parent_deme, x0) + opts["CMA_stds"] = get_initial_stds(deme_init_args.parent_deme, deme_init_args.sprout_seed) # We recommend to use sigma0 = 1 in this case. sigma0 = 1.0 if config.sigma0 is None else config.sigma0 - self._cma_es = CMAEvolutionStrategy(x0.genome, sigma0, inopts=opts) + self._cma_es = CMAEvolutionStrategy(x0, sigma0, inopts=opts) elif config.sigma0: - self._cma_es = CMAEvolutionStrategy(x0.genome, config.sigma0, inopts=opts) + self._cma_es = CMAEvolutionStrategy(x0, config.sigma0, inopts=opts) else: - sigma0 = get_initial_sigma0(parent_deme, x0) - self._cma_es = CMAEvolutionStrategy(x0.genome, sigma0, inopts=opts) + sigma0 = get_initial_sigma0(deme_init_args.parent_deme, deme_init_args.sprout_seed) + self._cma_es = CMAEvolutionStrategy(x0, sigma0, inopts=opts) starting_pop = [Individual(solution, problem=self._problem) for solution in self._cma_es.ask()] Individual.evaluate_population(starting_pop) diff --git a/pyhms/demes/de_deme.py b/pyhms/demes/de_deme.py index 6718a0c..e4f1f0a 100644 --- a/pyhms/demes/de_deme.py +++ b/pyhms/demes/de_deme.py @@ -1,23 +1,18 @@ -from pyhms.config import DELevelConfig -from pyhms.demes.abstract_deme import AbstractDeme +from pyhms.demes.abstract_deme import AbstractDeme, DemeInitArgs from pyhms.demes.single_pop_eas.de import DE from pyhms.initializers import sample_normal, sample_uniform -from structlog.typing import FilteringBoundLogger +from ..config import DELevelConfig from ..core.individual import Individual class DEDeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: DELevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, sprout_seed) + super().__init__(deme_init_args) + config: DELevelConfig = deme_init_args.config # type: ignore[assignment] self._pop_size = config.pop_size self._generations = config.generations self._sample_std_dev = config.sample_std_dev @@ -27,20 +22,20 @@ def __init__( f=config.scaling, ) - if sprout_seed is None: + if deme_init_args.sprout_seed is None: starting_pop = Individual.create_population( self._pop_size, initialize=sample_uniform(bounds=self._bounds), problem=self._problem, ) else: - x = sprout_seed.genome + x0 = deme_init_args.sprout_seed.genome starting_pop = Individual.create_population( self._pop_size - 1, - initialize=sample_normal(x, self._sample_std_dev, bounds=self._bounds), + initialize=sample_normal(x0, self._sample_std_dev, bounds=self._bounds), problem=self._problem, ) - seed_ind = Individual(x, problem=self._problem) + seed_ind = Individual(x0, problem=self._problem) starting_pop.append(seed_ind) Individual.evaluate_population(starting_pop) diff --git a/pyhms/demes/ea_deme.py b/pyhms/demes/ea_deme.py index 38b94d8..6b3727c 100644 --- a/pyhms/demes/ea_deme.py +++ b/pyhms/demes/ea_deme.py @@ -1,21 +1,16 @@ from pyhms.config import EALevelConfig from pyhms.core.individual import Individual -from pyhms.demes.abstract_deme import AbstractDeme +from pyhms.demes.abstract_deme import AbstractDeme, DemeInitArgs from pyhms.initializers import sample_normal, sample_uniform -from structlog.typing import FilteringBoundLogger class EADeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: EALevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, sprout_seed) + super().__init__(deme_init_args) + config: EALevelConfig = deme_init_args.config # type: ignore[assignment] self._sample_std_dev = config.sample_std_dev self._pop_size = config.pop_size self._generations = config.generations @@ -24,20 +19,20 @@ def __init__( ea_params["problem"] = self._problem self._ea = config.ea_class.create(**ea_params) - if sprout_seed is None: + if deme_init_args.sprout_seed is None: starting_pop = Individual.create_population( self._pop_size, initialize=sample_uniform(bounds=self._bounds), problem=self._problem, ) else: - x = sprout_seed.genome + x0 = deme_init_args.sprout_seed.genome starting_pop = Individual.create_population( self._pop_size - 1, - initialize=sample_normal(x, self._sample_std_dev, bounds=self._bounds), + initialize=sample_normal(x0, self._sample_std_dev, bounds=self._bounds), problem=self._problem, ) - seed_ind = Individual(x, problem=self._problem) + seed_ind = Individual(x0, problem=self._problem) starting_pop.append(seed_ind) Individual.evaluate_population(starting_pop) diff --git a/pyhms/demes/initialize.py b/pyhms/demes/initialize.py index 366874a..9fbd0e7 100644 --- a/pyhms/demes/initialize.py +++ b/pyhms/demes/initialize.py @@ -11,7 +11,7 @@ from structlog.typing import FilteringBoundLogger from ..core.individual import Individual -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs from .cma_deme import CMADeme from .de_deme import DEDeme from .ea_deme import EADeme @@ -20,9 +20,15 @@ from .shade_deme import SHADEDeme from .sobol_deme import SobolDeme - -def init_root(config: BaseLevelConfig, logger: FilteringBoundLogger) -> AbstractDeme: - return init_from_config(config, "root", 0, 0, None, logger) +CONFIG_CLASS_TO_DEME_CLASS = { + DELevelConfig: DEDeme, + SHADELevelConfig: SHADEDeme, + EALevelConfig: EADeme, + CMALevelConfig: CMADeme, + LocalOptimizationConfig: LocalDeme, + LHSLevelConfig: LHSDeme, + SobolLevelConfig: SobolDeme, +} def init_from_config( @@ -34,34 +40,17 @@ def init_from_config( logger: FilteringBoundLogger, random_seed: int = None, parent_deme: AbstractDeme | None = None, -) -> AbstractDeme: - args = { - "id": new_id, - "level": target_level, - "config": config, - "started_at": metaepoch_count, - "sprout_seed": sprout_seed, - "logger": logger, - } - child: AbstractDeme - if isinstance(config, DELevelConfig): - child = DEDeme(**args) - elif isinstance(config, SHADELevelConfig): - child = SHADEDeme(**args) - elif isinstance(config, EALevelConfig): - child = EADeme(**args) - elif isinstance(config, CMALevelConfig): - args["x0"] = sprout_seed - args.pop("sprout_seed", None) - args["random_seed"] = random_seed - args["parent_deme"] = parent_deme - child = CMADeme(**args) - elif isinstance(config, LocalOptimizationConfig): - child = LocalDeme(**args) - elif isinstance(config, LHSLevelConfig): - args["random_seed"] = random_seed - child = LHSDeme(**args) - elif isinstance(config, SobolLevelConfig): - args["random_seed"] = random_seed - child = SobolDeme(**args) - return child + config_class_to_deme_class: dict[type[BaseLevelConfig], type[AbstractDeme]] = {}, +): + deme_init_args = DemeInitArgs( + id=new_id, + level=target_level, + config=config, + started_at=metaepoch_count, + sprout_seed=sprout_seed, + logger=logger, + random_seed=random_seed, + parent_deme=parent_deme, + ) + merged_config_class_to_deme_class = config_class_to_deme_class | CONFIG_CLASS_TO_DEME_CLASS + return merged_config_class_to_deme_class[type(config)](deme_init_args) # type: ignore[abstract] diff --git a/pyhms/demes/lhs_deme.py b/pyhms/demes/lhs_deme.py index 773bb9b..a84cf10 100644 --- a/pyhms/demes/lhs_deme.py +++ b/pyhms/demes/lhs_deme.py @@ -2,24 +2,18 @@ from ..config import LHSLevelConfig from ..core.individual import Individual -from ..logging_ import FilteringBoundLogger -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs class LHSDeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: LHSLevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, - random_seed: int = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, sprout_seed) + super().__init__(deme_init_args) + config: LHSLevelConfig = deme_init_args.config # type: ignore[assignment] self._pop_size = config.pop_size - self.sampler = LatinHypercube(d=len(config.bounds), seed=random_seed) + self.sampler = LatinHypercube(d=len(config.bounds), seed=deme_init_args.random_seed) self.lower_bounds = config.bounds[:, 0] self.upper_bounds = config.bounds[:, 1] self.run() diff --git a/pyhms/demes/local_deme.py b/pyhms/demes/local_deme.py index ce4d55f..13796c5 100644 --- a/pyhms/demes/local_deme.py +++ b/pyhms/demes/local_deme.py @@ -1,24 +1,19 @@ from scipy import optimize as sopt -from structlog.typing import FilteringBoundLogger from ..config import LocalOptimizationConfig from ..core.individual import Individual -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs class LocalDeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: LocalOptimizationConfig, - logger: FilteringBoundLogger, - sprout_seed: Individual, - started_at=0, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at) + super().__init__(deme_init_args) + config: LocalOptimizationConfig = deme_init_args.config # type: ignore[assignment] self._method = config.method - self._sprout_seed = sprout_seed + self._sprout_seed = deme_init_args.sprout_seed self._n_evals = 0 starting_pop = [self._sprout_seed] self._history.append([starting_pop]) diff --git a/pyhms/demes/shade_deme.py b/pyhms/demes/shade_deme.py index a656edd..e6139cd 100644 --- a/pyhms/demes/shade_deme.py +++ b/pyhms/demes/shade_deme.py @@ -1,43 +1,37 @@ -from structlog.typing import FilteringBoundLogger - from ..config import SHADELevelConfig from ..core.individual import Individual from ..initializers import sample_normal, sample_uniform -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs from .single_pop_eas.de import SHADE class SHADEDeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: SHADELevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, sprout_seed) + super().__init__(deme_init_args) + config: SHADELevelConfig = deme_init_args.config # type: ignore[assignment] self._init_pop_size = config.pop_size self._pop_size = config.pop_size self._generations = config.generations self._sample_std_dev = config.sample_std_dev self._shade = SHADE(config.memory_size, config.pop_size) - if sprout_seed is None: + if deme_init_args.sprout_seed is None: starting_pop = Individual.create_population( self._pop_size, initialize=sample_uniform(bounds=self._bounds), problem=self._problem, ) else: - x = sprout_seed.genome + x0 = deme_init_args.sprout_seed.genome starting_pop = Individual.create_population( self._pop_size - 1, - initialize=sample_normal(x, self._sample_std_dev, bounds=self._bounds), + initialize=sample_normal(x0, self._sample_std_dev, bounds=self._bounds), problem=self._problem, ) - seed_ind = Individual(x, problem=self._problem) + seed_ind = Individual(x0, problem=self._problem) starting_pop.append(seed_ind) Individual.evaluate_population(starting_pop) diff --git a/pyhms/demes/sobol_deme.py b/pyhms/demes/sobol_deme.py index bd85e87..e5acbb4 100644 --- a/pyhms/demes/sobol_deme.py +++ b/pyhms/demes/sobol_deme.py @@ -2,24 +2,18 @@ from ..config import SobolLevelConfig from ..core.individual import Individual -from ..logging_ import FilteringBoundLogger -from .abstract_deme import AbstractDeme +from .abstract_deme import AbstractDeme, DemeInitArgs class SobolDeme(AbstractDeme): def __init__( self, - id: str, - level: int, - config: SobolLevelConfig, - logger: FilteringBoundLogger, - started_at: int = 0, - sprout_seed: Individual = None, - random_seed: int = None, + deme_init_args: DemeInitArgs, ) -> None: - super().__init__(id, level, config, logger, started_at, sprout_seed) + super().__init__(deme_init_args) + config: SobolLevelConfig = deme_init_args.config # type: ignore[assignment] self._pop_size = config.pop_size - self.sampler = Sobol(d=len(config.bounds), scramble=True, seed=random_seed) + self.sampler = Sobol(d=len(config.bounds), scramble=True, seed=deme_init_args.random_seed) self.lower_bounds = config.bounds[:, 0] self.upper_bounds = config.bounds[:, 1] self.run() diff --git a/pyhms/tree.py b/pyhms/tree.py index 1ad84dc..ad92612 100644 --- a/pyhms/tree.py +++ b/pyhms/tree.py @@ -17,7 +17,7 @@ from .core.problem import StatsGatheringProblem, get_function_problem from .demes.abstract_deme import AbstractDeme from .demes.cma_deme import CMADeme -from .demes.initialize import init_from_config, init_root +from .demes.initialize import init_from_config from .logging_ import DEFAULT_LOGGING_LEVEL, get_logger from .sprout.sprout_candidates import DemeCandidates from .sprout.sprout_mechanisms import SproutMechanism @@ -57,7 +57,16 @@ def __init__(self, config: TreeConfig) -> None: self._random_seed = None self._levels: list[list[AbstractDeme]] = [[] for _ in range(nlevels)] - root_deme = init_root(config.levels[0], self._logger) + root_deme = init_from_config( + config=config.levels[0], + new_id="root", + target_level=0, + metaepoch_count=0, + sprout_seed=None, + logger=self._logger, + random_seed=self._random_seed, + config_class_to_deme_class=self.config.config_class_to_deme_class, + ) self._levels[0].append(root_deme) @property @@ -173,14 +182,15 @@ def _do_sprout(self, deme_seeds: dict[AbstractDeme, DemeCandidates]) -> None: config = self.config.levels[target_level] child = init_from_config( - config, - new_id, - target_level, - self.metaepoch_count, + config=config, + new_id=new_id, + target_level=target_level, + metaepoch_count=self.metaepoch_count, sprout_seed=ind, logger=self._logger, random_seed=self._random_seed, parent_deme=deme, + config_class_to_deme_class=self.config.config_class_to_deme_class, ) deme.add_child(child) self._levels[target_level].append(child) @@ -628,18 +638,25 @@ def plot_population( height=1000, template="plotly_white", font=dict(size=16), + margin=dict(t=30, pad=4), + xaxis=dict(range=[bounds[0][0], bounds[0][1]]), + yaxis=dict(range=[bounds[1][0], bounds[1][1]]), coloraxis_colorbar=dict( title=dict( text="f(x, y)", side="right", font=dict(size=16), ), - x=1.17, + x=1.02, y=0.5, len=0.8, ), legend=dict( - y=0.5, + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="center", + x=0.5, ), ) if filepath: diff --git a/test/test_tree_format.py b/test/test_tree_format.py index 24d80e8..7d8a743 100644 --- a/test/test_tree_format.py +++ b/test/test_tree_format.py @@ -1,6 +1,7 @@ import numpy as np from pyhms.config import CMALevelConfig from pyhms.core.individual import Individual +from pyhms.demes.abstract_deme import DemeInitArgs from pyhms.demes.cma_deme import CMADeme from pyhms.logging_ import get_logger from pyhms.tree import DemeTree @@ -17,13 +18,14 @@ def test_format_deme(): sigma0=1.0, generations=1, ) - deme = CMADeme( + deme_init_args = DemeInitArgs( id="0", level=0, config=config, logger=logger, - x0=Individual(genome=np.array([0, 0]), problem=SQUARE_PROBLEM), + sprout_seed=Individual(genome=np.array([0, 0]), problem=SQUARE_PROBLEM), ) + deme = CMADeme(deme_init_args) formatted_deme = format_deme(deme) assert formatted_deme.startswith("CMADeme 0") assert "sprout: (0.00, 0.00);" in formatted_deme