diff --git a/decent_bench/agents.py b/decent_bench/agents.py index 597f3de..51ebb10 100644 --- a/decent_bench/agents.py +++ b/decent_bench/agents.py @@ -11,15 +11,21 @@ class Agent: - """Agent with unique id, local cost function, and activation scheme.""" + """Agent with unique id, local cost function, activation scheme and state snapshot period.""" + + def __init__(self, agent_id: int, cost: Cost, activation: AgentActivationScheme, state_snapshot_period: int): + if state_snapshot_period <= 0: + raise ValueError("state_snapshot_period must be a positive integer") - def __init__(self, agent_id: int, cost: Cost, activation: AgentActivationScheme): self._id = agent_id self._cost = cost self._activation = activation - self._x_history: list[Array] = [] + self._state_snapshot_period = state_snapshot_period + self._current_x: Array | None = None + self._x_history: dict[int, Array] = {} self._auxiliary_variables: dict[str, Array] = {} self._received_messages: dict[Agent, Array] = {} + self._n_x_updates = 0 self._n_sent_messages = 0 self._n_received_messages = 0 self._n_sent_messages_dropped = 0 @@ -55,24 +61,20 @@ def x(self) -> Array: """ Local optimization variable x. - Warning: - Do not use in-place operations (``+=``, ``-=``, ``*=``, etc.) on this property. - In-place operations will corrupt the optimization history by modifying all - historical values. Always use ``agent.x = agent.x + value`` instead of - ``agent.x += value``. Does not affect the outcome of the optimization, but - will affect logging and metrics that depend on the optimization history. - Raises: RuntimeError: if x is retrieved before being set or initialized """ - if not self._x_history: + if self._current_x is None: raise RuntimeError("x must be initialized before being accessed") - return self._x_history[-1] + return self._current_x @x.setter def x(self, x: Array) -> None: - self._x_history.append(x) + self._n_x_updates += 1 + self._current_x = x + if self._n_x_updates % self._state_snapshot_period == 0: + self._x_history[self._n_x_updates] = iop.copy(x) @property def messages(self) -> Mapping[Agent, Array]: @@ -106,7 +108,9 @@ def initialize( if x is not None: if iop.shape(x) != self.cost.shape: raise ValueError(f"Initialized x has shape {iop.shape(x)}, expected {self.cost.shape}") - self._x_history = [iop.copy(x)] + self._x_history = {0: iop.copy(x)} + self._current_x = iop.copy(x) + self._n_x_updates = 0 if aux_vars: self._auxiliary_variables = {k: iop.copy(v) for k, v in aux_vars.items()} if received_msgs: @@ -138,7 +142,8 @@ class AgentMetricsView: """Immutable view of agent that exposes useful properties for calculating metrics.""" cost: Cost - x_history: list[Array] + x_history: dict[int, Array] + n_x_updates: int n_function_calls: int n_gradient_calls: int n_hessian_calls: int @@ -150,9 +155,14 @@ class AgentMetricsView: @staticmethod def from_agent(agent: Agent) -> AgentMetricsView: """Create from agent.""" + # Append the last x if not already recorded + if agent._current_x is not None and agent._n_x_updates not in agent._x_history: # noqa: SLF001 + agent._x_history[agent._n_x_updates] = iop.copy(agent._current_x) # noqa: SLF001 + return AgentMetricsView( cost=agent.cost, x_history=agent._x_history, # noqa: SLF001 + n_x_updates=agent._n_x_updates, # noqa: SLF001 n_function_calls=agent._n_function_calls, # noqa: SLF001 n_gradient_calls=agent._n_gradient_calls, # noqa: SLF001 n_hessian_calls=agent._n_hessian_calls, # noqa: SLF001 diff --git a/decent_bench/benchmark.py b/decent_bench/benchmark.py index 1ce7d1b..3a953cd 100644 --- a/decent_bench/benchmark.py +++ b/decent_bench/benchmark.py @@ -31,6 +31,10 @@ def benchmark( table_metrics: list[TableMetric] = DEFAULT_TABLE_METRICS, table_fmt: Literal["grid", "latex"] = "grid", *, + plot_grid: bool = True, + plot_path: str | None = None, + computational_cost: pm.ComputationalCost | None = None, + x_axis_scaling: float = 1e-4, n_trials: int = 30, confidence_level: float = 0.95, log_level: int = logging.INFO, @@ -38,6 +42,7 @@ def benchmark( progress_step: int | None = None, show_speed: bool = False, show_trial: bool = False, + compare_iterations_and_computational_cost: bool = False, ) -> None: """ Benchmark distributed algorithms. @@ -51,6 +56,13 @@ def benchmark( table_metrics: metrics to tabulate as confidence intervals after the execution, defaults to :const:`~decent_bench.metrics.table_metrics.DEFAULT_TABLE_METRICS` table_fmt: table format, grid is suitable for the terminal while latex can be copy-pasted into a latex document + plot_grid: whether to show grid lines on the plots + plot_path: optional file path to save the generated plot as an image file (e.g., "plots.png"). If ``None``, + the plot will only be displayed + computational_cost: computational cost settings for plot metrics, if ``None`` x-axis will be iterations instead + of computational cost + x_axis_scaling: scaling factor for computational cost x-axis, used to convert the cost units into more + manageable units for plotting. Only used if ``computational_cost`` is provided. n_trials: number of times to run each algorithm on the benchmark problem, running more trials improves the statistical results, at least 30 trials are recommended for the central limit theorem to apply confidence_level: confidence level of the confidence intervals @@ -64,15 +76,27 @@ def benchmark( If `None`, the progress bar uses 1 unit per trial. show_speed: whether to show speed (iterations/second) in the progress bar. show_trial: whether to show which trials are currently running in the progress bar. + compare_iterations_and_computational_cost: whether to plot both metric vs computational cost and + metric vs iterations. Only used if ``computational_cost`` is provided. Note: If ``progress_step`` is too small performance may degrade due to the overhead of updating the progress bar too often. + Computational cost can be interpreted as the cost of running the algorithm on a specific hardware setup. + Therefore the computational cost could be seen as the number of operations performed (similar to FLOPS) but + weighted by the time or energy it takes to perform them on the specific hardware. + + .. include:: snippets/computational_cost.rst + + If ``computational_cost`` is provided and ``compare_iterations_and_computational_cost`` is ``True``, each metric + will be plotted twice: once against computational cost and once against iterations. + Computational cost plots will be shown on the left and iteration plots on the right. + """ manager = Manager() log_listener = logger.start_log_listener(manager, log_level) - LOGGER.info("Starting benchmark execution, progress bar increments with each completed trial ") + LOGGER.info("Starting benchmark execution ") with Status("Generating initial network state"): nw_init_state = create_distributed_network(benchmark_problem) LOGGER.debug(f"Nr of agents: {len(nw_init_state.agents())}") @@ -82,10 +106,17 @@ def benchmark( resulting_agent_states: dict[Algorithm, list[list[AgentMetricsView]]] = {} for alg, networks in resulting_nw_states.items(): resulting_agent_states[alg] = [[AgentMetricsView.from_agent(a) for a in nw.agents()] for nw in networks] - with Status("Creating table"): - tm.tabulate(resulting_agent_states, benchmark_problem, table_metrics, confidence_level, table_fmt) - with Status("Creating plot"): - pm.plot(resulting_agent_states, benchmark_problem, plot_metrics) + tm.tabulate(resulting_agent_states, benchmark_problem, table_metrics, confidence_level, table_fmt) + pm.plot( + resulting_agent_states, + benchmark_problem, + plot_metrics, + computational_cost, + x_axis_scaling, + compare_iterations_and_computational_cost, + plot_path, + plot_grid, + ) LOGGER.info("Benchmark execution complete, thanks for using decent-bench") log_listener.stop() diff --git a/decent_bench/benchmark_problem.py b/decent_bench/benchmark_problem.py index ad7377a..b479916 100644 --- a/decent_bench/benchmark_problem.py +++ b/decent_bench/benchmark_problem.py @@ -41,6 +41,7 @@ class BenchmarkProblem: network_structure: graph defining how agents are connected x_optimal: solution that minimizes the sum of the cost functions, used for calculating metrics costs: local cost functions, each one is given to one agent + agent_state_snapshot_period: period for recording agent state snapshots, used for plot metrics agent_activations: setting for agent activation/participation, each scheme is applied to one agent message_compression: message compression setting message_noise: message noise setting @@ -51,6 +52,7 @@ class BenchmarkProblem: network_structure: AnyGraph x_optimal: Array costs: Sequence[Cost] + agent_state_snapshot_period: int agent_activations: Sequence[AgentActivationScheme] message_compression: CompressionScheme message_noise: NoiseScheme @@ -61,6 +63,7 @@ def create_regression_problem( cost_cls: type[LinearRegressionCost | LogisticRegressionCost], *, n_agents: int = 100, + agent_state_snapshot_period: int = 1, n_neighbors_per_agent: int = 3, asynchrony: bool = False, compression: bool = False, @@ -73,6 +76,7 @@ def create_regression_problem( Args: cost_cls: type of cost function n_agents: number of agents + agent_state_snapshot_period: period for recording agent state snapshots, used for plot metrics n_neighbors_per_agent: number of neighbors per agent asynchrony: if true, agents only have a 50% probability of being active/participating at any given time compression: if true, messages are rounded to 4 significant digits @@ -100,6 +104,7 @@ def create_regression_problem( return BenchmarkProblem( network_structure=network_structure, costs=costs, + agent_state_snapshot_period=agent_state_snapshot_period, x_optimal=x_optimal, agent_activations=agent_activations, message_compression=message_compression, diff --git a/decent_bench/metrics/metric_utils.py b/decent_bench/metrics/metric_utils.py index dfac782..febf36a 100644 --- a/decent_bench/metrics/metric_utils.py +++ b/decent_bench/metrics/metric_utils.py @@ -6,6 +6,7 @@ from numpy import linalg as la from numpy.linalg import LinAlgError from numpy.typing import NDArray +from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn import decent_bench.utils.interoperability as iop from decent_bench.agents import AgentMetricsView @@ -13,6 +14,24 @@ from decent_bench.utils.array import Array +class MetricProgressBar(Progress): + """ + Progress bar for metric calculations. + + Make sure to set the field *status* in the task to show custom status messages. + + """ + + def __init__(self) -> None: + super().__init__( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(elapsed_when_finished=True), + TextColumn("{task.fields[status]}"), + ) + + def single(values: Sequence[float]) -> float: """ Assert that *values* contain exactly one element and return it. @@ -37,7 +56,11 @@ def x_mean(agents: tuple[AgentMetricsView, ...], iteration: int = -1) -> Array: ValueError: if no agent reached *iteration* """ - all_x_at_iter = [a.x_history[iteration] for a in agents if len(a.x_history) > iteration] + if iteration == -1: + all_x_at_iter = [a.x_history[max(a.x_history)] for a in agents if len(a.x_history) > 0] + else: + all_x_at_iter = [a.x_history[iteration] for a in agents if iteration in a.x_history] + if len(all_x_at_iter) == 0: raise ValueError(f"No agent reached iteration {iteration}") @@ -56,7 +79,7 @@ def regret(agents: list[AgentMetricsView], problem: BenchmarkProblem, iteration: mean_x = x_mean(tuple(agents), iteration) optimal_cost = sum(a.cost.function(x_opt) for a in agents) actual_cost = sum(a.cost.function(mean_x) for a in agents) - return abs(optimal_cost - actual_cost) + return actual_cost - optimal_cost def gradient_norm(agents: list[AgentMetricsView], iteration: int = -1) -> float: @@ -83,7 +106,7 @@ def x_error(agent: AgentMetricsView, problem: BenchmarkProblem) -> NDArray[float where :math:`\mathbf{x}_k` is the agent's local x at iteration k, and :math:`\mathbf{x}^\star` is the optimal x defined in the *problem*. """ - x_per_iteration = np.asarray([iop.to_numpy(x) for x in agent.x_history]) + x_per_iteration = np.asarray([iop.to_numpy(x) for _, x in sorted(agent.x_history.items())]) opt_x = problem.x_optimal errors: NDArray[float64] = la.norm(x_per_iteration - opt_x, axis=tuple(range(1, x_per_iteration.ndim))) return errors @@ -131,3 +154,22 @@ def iterative_convergence_rate_and_order(agent: AgentMetricsView, problem: Bench except LinAlgError: rate, order = np.nan, np.nan return rate, order + + +def common_sorted_iterations(agents: Sequence[AgentMetricsView]) -> list[int]: + """ + Get a sorted list of all common iterations reached by agents in *agents*. + + Since the agents can sample their states periodically, and may sample at different iterations, + this function returns only the iterations that are common to all agents. These iterations can then be used + to compute metrics that require synchronized iterations. + + Args: + agents: sequence of agents to get the common iterations from + + Returns: + sorted list of iterations reached by all agents + + """ + common_iters = set.intersection(*(set(a.x_history.keys()) for a in agents)) if agents else set() + return sorted(common_iters) diff --git a/decent_bench/metrics/plot_metrics.py b/decent_bench/metrics/plot_metrics.py index 0a0e30c..8729b87 100644 --- a/decent_bench/metrics/plot_metrics.py +++ b/decent_bench/metrics/plot_metrics.py @@ -1,13 +1,14 @@ import math -import random import warnings from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Sequence +from dataclasses import dataclass import matplotlib.pyplot as plt import numpy as np from matplotlib.axes import Axes as SubPlot +from matplotlib.figure import Figure import decent_bench.metrics.metric_utils as utils from decent_bench.agents import AgentMetricsView @@ -19,6 +20,17 @@ Y = float +@dataclass +class ComputationalCost: + """Computational costs associated with an algorithm for plot metrics.""" + + function: float = 1.0 + gradient: float = 1.0 + hessian: float = 1.0 + proximal: float = 1.0 + communication: float = 1.0 + + class PlotMetric(ABC): """ Metric to plot at the end of the benchmarking execution. @@ -35,12 +47,7 @@ def __init__(self, *, x_log: bool = False, y_log: bool = True): @property @abstractmethod - def x_label(self) -> str: - """Label for the x-axis.""" - - @property - @abstractmethod - def y_label(self) -> str: + def plot_description(self) -> str: """Label for the y-axis.""" @abstractmethod @@ -61,12 +68,11 @@ class RegretPerIteration(PlotMetric): its calculation. """ - x_label: str = "iteration" - y_label: str = "regret" + plot_description: str = "regret" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[tuple[X, Y]]: # noqa: D102 - iter_reached_by_all = min(len(a.x_history) for a in agents) - return [(i, utils.regret(agents, problem, i)) for i in range(iter_reached_by_all)] + # Determine the set of recorded iterations common to all agents and use those + return [(i, utils.regret(agents, problem, i)) for i in utils.common_sorted_iterations(agents)] class GradientNormPerIteration(PlotMetric): @@ -82,12 +88,11 @@ class GradientNormPerIteration(PlotMetric): included in the calculation. """ - x_label: str = "iteration" - y_label: str = "gradient norm" + plot_description: str = "gradient norm" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[tuple[X, Y]]: # noqa: D102 - iter_reached_by_all = min(len(a.x_history) for a in agents) - return [(i, utils.gradient_norm(agents, i)) for i in range(iter_reached_by_all)] + # Determine the set of recorded iterations common to all agents and use those + return [(i, utils.gradient_norm(agents, i)) for i in utils.common_sorted_iterations(agents)] DEFAULT_PLOT_METRICS = [ @@ -102,14 +107,38 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble """ PLOT_METRICS_DOC_LINK = "https://decent-bench.readthedocs.io/en/latest/api/decent_bench.metrics.plot_metrics.html" -COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"] -MARKERS = ["o", "s", "v", "^", "*", "D", "H", "<", ">", "p"] +X_LABELS = { + "iterations": "iterations", + "computational_cost": "time (computational cost units)", +} +COLORS = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", + "#34495e", + "#16a085", + "#686901", +] +MARKERS = ["o", "s", "v", "^", "*", "D", "H", "<", ">", "p", "P", "X"] +STYLES = ["-", ":", "--", "-.", (5, (10, 3)), (0, (5, 10)), (0, (3, 1, 1, 1))] -def plot( +def plot( # noqa: PLR0917 resulting_agent_states: dict[Algorithm, list[list[AgentMetricsView]]], problem: BenchmarkProblem, metrics: list[PlotMetric], + computational_cost: ComputationalCost | None, + x_axis_scaling: float = 1e-4, + compare_iterations_and_computational_cost: bool = False, + plot_path: str | None = None, + plot_grid: bool = True, ) -> None: """ Plot the execution results with one subplot per metric. @@ -121,55 +150,267 @@ def plot( problem: benchmark problem whose properties, e.g. :attr:`~decent_bench.benchmark_problem.BenchmarkProblem.x_optimal`, are used for metric calculations metrics: metrics to calculate and plot - - Raises: - RuntimeError: if the figure manager can't be retrieved + computational_cost: computational cost settings for plot metrics, if ``None`` x-axis will be iterations instead + of computational cost + x_axis_scaling: scaling factor for computational cost x-axis, used to convert the cost units into more + manageable units for plotting. Only used if ``computational_cost`` is provided + compare_iterations_and_computational_cost: whether to plot both metric vs computational cost and + metric vs iterations. Only used if ``computational_cost`` is provided + plot_path: optional file path to save the generated plot as an image file (e.g., "plots.png"). If ``None``, + the plot will only be displayed + plot_grid: whether to show grid lines on the plots + + Note: + Computational cost can be interpreted as the cost of running the algorithm on a specific hardware setup. + Therefore the computational cost could be seen as the number of operations performed (similar to FLOPS) but + weighted by the time or energy it takes to perform them on the specific hardware. + + .. include:: snippets/computational_cost.rst """ if not metrics: return LOGGER.info(f"Plot metric definitions can be found here: {PLOT_METRICS_DOC_LINK}") - metric_subplots: list[tuple[PlotMetric, SubPlot]] = _create_metric_subplots(metrics) - for metric, subplot in metric_subplots: - for i, (alg, agent_states) in enumerate(resulting_agent_states.items()): - color = COLORS[i] if i < len(COLORS) else [random.random() for _ in range(3)] - marker = MARKERS[i] if i < len(MARKERS) else random.choice(MARKERS) - data_per_trial: list[Sequence[tuple[X, Y]]] = _get_data_per_trial(agent_states, problem, metric) - flattened_data: list[tuple[X, Y]] = [d for trial in data_per_trial for d in trial] - if not np.isfinite(flattened_data).all(): - msg = f"Skipping plot {metric.y_label}/{metric.x_label} for {alg.name}: found nan or inf in datapoints." - LOGGER.warning(msg) - continue - mean_curve: Sequence[tuple[X, Y]] = _calculate_mean_curve(data_per_trial) - x, y_mean = zip(*mean_curve, strict=True) - subplot.plot(x, y_mean, label=alg.name, color=color, marker=marker, markevery=max(1, int(len(x) / 20))) - y_min, y_max = _calculate_envelope(data_per_trial) - subplot.fill_between(x, y_min, y_max, color=color, alpha=0.1) - subplot.legend() - manager = plt.get_current_fig_manager() - if not manager: - raise RuntimeError("Something went wrong, did not receive a FigureManager...") - plt.tight_layout() - plt.show() + + if len(metrics) > 4: + LOGGER.warning( + f"Plotting {len(metrics)} (> 4) metrics may result in a cluttered figure. " + "Consider reducing the number of metrics for better readability." + ) + + did_plot = False + use_cost = computational_cost is not None + two_columns = use_cost and compare_iterations_and_computational_cost + fig, metric_subplots = _create_metric_subplots( + metrics, + use_cost, + compare_iterations_and_computational_cost, + plot_grid, + ) + with utils.MetricProgressBar() as progress: + plot_task = progress.add_task( + "Generating plots", + total=len(metric_subplots) * len(resulting_agent_states), + status="", + ) + x_label = X_LABELS["computational_cost" if use_cost else "iterations"] + for metric_index in range(len(metrics)): + progress.update( + plot_task, + status=f"Task: {metrics[metric_index].plot_description} vs {x_label}", + ) + for i, (alg, agent_states) in enumerate(resulting_agent_states.items()): + data_per_trial: list[Sequence[tuple[X, Y]]] = _get_data_per_trial( + agent_states, problem, metrics[metric_index] + ) + if not _is_finite(data_per_trial): + msg = ( + f"Skipping plot {metrics[metric_index].plot_description}/{x_label} " + f"for {alg.name}: found nan or inf in datapoints." + ) + LOGGER.warning(msg) + progress.advance(plot_task, 2 if two_columns else 1) + continue + _plot( + metric_subplots, + data_per_trial, + computational_cost, + compare_iterations_and_computational_cost, + x_axis_scaling, + agent_states, + alg, + metric_index, + i, + ) + did_plot = True + progress.advance(plot_task, 2 if two_columns else 1) + progress.update(plot_task, status="Finalizing plots") + + if not did_plot: + LOGGER.warning("No plots were generated due to invalid data.") + return + + _show_figure(fig, metric_subplots, two_columns, plot_path) -def _create_metric_subplots(metrics: list[PlotMetric]) -> list[tuple[PlotMetric, SubPlot]]: - subplots_per_row = 2 - n_metrics = len(metrics) - n_rows = math.ceil(n_metrics / subplots_per_row) - fig, subplots = plt.subplots(nrows=n_rows, ncols=subplots_per_row) - subplots = subplots.flatten() - for sp in subplots[n_metrics:]: +def _create_metric_subplots( + metrics: list[PlotMetric], + use_cost: bool, + compare_iterations_and_computational_cost: bool, + plot_grid: bool, +) -> tuple[Figure, list[SubPlot]]: + n_cols = 2 if use_cost and compare_iterations_and_computational_cost else 1 + n_plots = len(metrics) * n_cols + n_rows = math.ceil(n_plots / n_cols) + + fig, subplot_axes = plt.subplots( + nrows=n_rows, + ncols=n_cols, + sharex="col", + sharey="row", + layout="constrained", + ) + if isinstance(subplot_axes, SubPlot): + subplots: list[SubPlot] = [subplot_axes] + else: + subplots = subplot_axes.flatten() + + if subplots is None: + raise RuntimeError("Something went wrong, did not receive subplot axes...") + + for sp in subplots[n_plots + n_cols :]: fig.delaxes(sp) - metric_subplots = list(zip(metrics, subplots[:n_metrics], strict=True)) - for metric, sp in metric_subplots: - sp.set_xlabel(metric.x_label) - sp.set_ylabel(metric.y_label) + + for i in range(n_plots): + metric = metrics[i // (2 if n_cols == 2 else 1)] + sp = subplots[i] + + # Only set x label for subplots in the last row + if i // n_cols == n_rows - 1: + # For comparison mode, right column shows iterations, left shows cost + if n_cols == 2: + sp.set_xlabel(X_LABELS["iterations"] if i % 2 == 1 else X_LABELS["computational_cost"]) + else: + # Single column mode: show cost if enabled, otherwise iterations + sp.set_xlabel(X_LABELS["computational_cost" if use_cost else "iterations"]) + + # Only set y label for left column subplots + if i % n_cols == 0: + sp.set_ylabel(metric.plot_description) + if metric.x_log: sp.set_xscale("log") if metric.y_log: sp.set_yscale("log") - return metric_subplots + + if plot_grid: + sp.grid(True, which="major", linestyle="--", linewidth=0.5, alpha=0.7) # noqa: FBT003 + + return fig, subplots[:n_plots] + + +def _show_figure( + fig: Figure, + metric_subplots: list[SubPlot], + two_columns: bool, + plot_path: str | None = None, +) -> None: + manager = plt.get_current_fig_manager() + if not manager: + raise RuntimeError("Something went wrong, did not receive a FigureManager...") + + # Create a single legend at the top of the figure + handles, labels = metric_subplots[0].get_legend_handles_labels() + label_cols = min(len(labels), 4 if two_columns else 3) + + # Create the legend to get the height of the legend box + fig.legend( + handles, + labels, + loc="outside upper center", + ncol=label_cols, + frameon=True, + ) + + if plot_path is not None: + fig.savefig(plot_path, dpi=300) + LOGGER.info(f"Saved plot to: {plot_path}") + + plt.show() + + +def _is_finite(data_per_trial: list[Sequence[tuple[X, Y]]]) -> bool: + flattened_data: list[tuple[X, Y]] = [d for trial in data_per_trial for d in trial] + return np.isfinite(flattened_data).all().item() + + +def _plot( # noqa: PLR0917 + metric_subplots: list[SubPlot], + data_per_trial: list[Sequence[tuple[X, Y]]], + computational_cost: ComputationalCost | None, + compare_iterations_and_computational_cost: bool, + x_axis_scaling: float, + agent_states: list[list[AgentMetricsView]], + alg: Algorithm, + metric_index: int, + iteration: int, +) -> None: + use_cost = computational_cost is not None + subplot_idx = metric_index * (2 if use_cost and compare_iterations_and_computational_cost else 1) + + mean_curve: Sequence[tuple[X, Y]] = _calculate_mean_curve(data_per_trial) + x, y_mean = zip(*mean_curve, strict=True) + y_min, y_max = _calculate_envelope(data_per_trial) + if computational_cost is not None: + total_computational_cost = _calc_total_cost(agent_states, computational_cost) + x_computational = tuple(val * total_computational_cost * x_axis_scaling for val in x) + if compare_iterations_and_computational_cost: + # Plot value vs iterations subplot first + iter_idx = metric_index * 2 + 1 + _plot_subplot(metric_subplots[iter_idx], x, y_mean, y_min, y_max, alg.name, iteration) + x = x_computational + _plot_subplot(metric_subplots[subplot_idx], x, y_mean, y_min, y_max, alg.name, iteration) + + +def _plot_subplot( # noqa: PLR0917 + subplot: SubPlot, + x: Sequence[float], + y_mean: Sequence[float], + y_min: Sequence[float], + y_max: Sequence[float], + label: str, + iteration: int, +) -> None: + marker, linestyle, color = _get_marker_style_color(iteration) + subplot.plot( + x, + y_mean, + label=label, + color=color, + marker=marker, + linestyle=linestyle, + markevery=max(1, int(len(x) / 10)), + ) + subplot.fill_between(x, y_min, y_max, color=color, alpha=0.1) + + +def _get_marker_style_color( + index: int, +) -> tuple[str, Sequence[int | tuple[int, int, int, int] | str | tuple[int, int]], str]: + """ + Get deterministic unique marker, line style, and color for a given index. + + Cycles through all combinations to ensure the first n indices (where n = + len(MARKERS) * len(STYLES)) are unique. Colors cycle based on index, + markers cycle first, then styles to maximize marker distinctiveness for B&W printing. + """ + # Calculate total unique combinations + n_combinations = len(MARKERS) * len(STYLES) + + # Reduce index to valid range + idx = index % n_combinations + + color_idx = index % len(COLORS) + marker_idx = idx % len(MARKERS) + style_idx = (idx // len(MARKERS)) % len(STYLES) + + return MARKERS[marker_idx], STYLES[style_idx], COLORS[color_idx] + + +def _calc_total_cost(agent_states: list[list[AgentMetricsView]], computational_cost: ComputationalCost) -> float: + mean_function_calls = np.mean([a.n_function_calls for agents in agent_states for a in agents]) + mean_gradient_calls = np.mean([a.n_gradient_calls for agents in agent_states for a in agents]) + mean_hessian_calls = np.mean([a.n_hessian_calls for agents in agent_states for a in agents]) + mean_proximal_calls = np.mean([a.n_proximal_calls for agents in agent_states for a in agents]) + mean_communication_calls = np.mean([a.n_sent_messages for agents in agent_states for a in agents]) + + return float( + computational_cost.function * mean_function_calls + + computational_cost.gradient * mean_gradient_calls + + computational_cost.hessian * mean_hessian_calls + + computational_cost.proximal * mean_proximal_calls + + computational_cost.communication * mean_communication_calls + ) def _get_data_per_trial( diff --git a/decent_bench/metrics/table_metrics.py b/decent_bench/metrics/table_metrics.py index 2bc1130..70e2e2f 100644 --- a/decent_bench/metrics/table_metrics.py +++ b/decent_bench/metrics/table_metrics.py @@ -26,15 +26,24 @@ class TableMetric(ABC): statistics: sequence of statistics such as :func:`min`, :func:`sum`, and :func:`~numpy.average` used for aggregating the data retrieved with :func:`get_data_from_trial` into a single value, each statistic gets its own row in the table + fmt: format string used to format the values in the table, defaults to ".2e". Common formats include: + - ".2e": scientific notation with 2 decimal places + - ".3f": fixed-point notation with 3 decimal places + - ".4g": general format with 4 significant digits + - ".1%": percentage format with 1 decimal place + + Where the integer specifies the precision. + See :meth:`str.format` documentation for details on the format string options. """ - def __init__(self, statistics: list[Statistic]): + def __init__(self, statistics: list[Statistic], fmt: str = ".2e"): self.statistics = statistics + self.fmt = fmt @property @abstractmethod - def description(self) -> str: + def table_description(self) -> str: """Metric description to display in the table.""" @abstractmethod @@ -51,7 +60,7 @@ class Regret(TableMetric): .. include:: snippets/global_cost_error.rst """ - description: str = "regret \n[<1e-9 = exact conv.]" + table_description: str = "regret \n[<1e-9 = exact conv.]" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> tuple[float]: # noqa: D102 return (utils.regret(agents, problem, iteration=-1),) @@ -66,7 +75,7 @@ class GradientNorm(TableMetric): .. include:: snippets/global_gradient_optimality.rst """ - description: str = "gradient norm" + table_description: str = "gradient norm" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> tuple[float]: # noqa: D102 return (utils.gradient_norm(agents, iteration=-1),) @@ -85,10 +94,13 @@ class XError(TableMetric): """ - description: str = "x error" + table_description: str = "x error" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 - return [float(la.norm(iop.to_numpy(problem.x_optimal) - iop.to_numpy(a.x_history[-1]))) for a in agents] + return [ + float(la.norm(iop.to_numpy(problem.x_optimal) - iop.to_numpy(a.x_history[max(a.x_history)]))) + for a in agents + ] class AsymptoticConvergenceOrder(TableMetric): @@ -98,7 +110,7 @@ class AsymptoticConvergenceOrder(TableMetric): .. include:: snippets/asymptotic_convergence_rate_and_order.rst """ - description: str = "asymptotic convergence order" + table_description: str = "asymptotic convergence order" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 return [utils.asymptotic_convergence_rate_and_order(a, problem)[1] for a in agents] @@ -111,7 +123,7 @@ class AsymptoticConvergenceRate(TableMetric): .. include:: snippets/asymptotic_convergence_rate_and_order.rst """ - description: str = "asymptotic convergence rate" + table_description: str = "asymptotic convergence rate" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 return [utils.asymptotic_convergence_rate_and_order(a, problem)[0] for a in agents] @@ -124,7 +136,7 @@ class IterativeConvergenceOrder(TableMetric): .. include:: snippets/iterative_convergence_rate_and_order.rst """ - description: str = "iterative convergence order" + table_description: str = "iterative convergence order" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 return [utils.iterative_convergence_rate_and_order(a, problem)[1] for a in agents] @@ -137,7 +149,7 @@ class IterativeConvergenceRate(TableMetric): .. include:: snippets/iterative_convergence_rate_and_order.rst """ - description: str = "iterative convergence rate" + table_description: str = "iterative convergence rate" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 return [utils.iterative_convergence_rate_and_order(a, problem)[0] for a in agents] @@ -146,16 +158,16 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], problem: Benchmark class XUpdates(TableMetric): """Number of iterations/updates of x per agent.""" - description: str = "nr x updates" + table_description: str = "nr x updates" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 - return [len(a.x_history) - 1 for a in agents] + return [a.n_x_updates for a in agents] class FunctionCalls(TableMetric): """Number of cost function evaluate calls per agent.""" - description: str = "nr function calls" + table_description: str = "nr function calls" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_function_calls for a in agents] @@ -164,7 +176,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class GradientCalls(TableMetric): """Number of cost function gradient calls per agent.""" - description: str = "nr gradient calls" + table_description: str = "nr gradient calls" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_gradient_calls for a in agents] @@ -173,7 +185,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class HessianCalls(TableMetric): """Number of cost function hessian calls per agent.""" - description: str = "nr hessian calls" + table_description: str = "nr hessian calls" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_hessian_calls for a in agents] @@ -182,7 +194,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class ProximalCalls(TableMetric): """Number of cost function proximal calls per agent.""" - description: str = "nr proximal calls" + table_description: str = "nr proximal calls" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_proximal_calls for a in agents] @@ -191,7 +203,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class SentMessages(TableMetric): """Number of sent messages per agent.""" - description: str = "nr sent messages" + table_description: str = "nr sent messages" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_sent_messages for a in agents] @@ -200,7 +212,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class ReceivedMessages(TableMetric): """Number of received messages per agent.""" - description: str = "nr received messages" + table_description: str = "nr received messages" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_received_messages for a in agents] @@ -209,7 +221,7 @@ def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProble class SentMessagesDropped(TableMetric): """Number of sent messages that were dropped per agent.""" - description: str = "nr sent messages dropped" + table_description: str = "nr sent messages dropped" def get_data_from_trial(self, agents: list[AgentMetricsView], _: BenchmarkProblem) -> list[float]: # noqa: D102 return [a.n_sent_messages_dropped for a in agents] @@ -283,30 +295,35 @@ def tabulate( headers = ["Metric (statistic)"] + [alg.name for alg in algs] rows: list[list[str]] = [] statistics_abbr = {"average": "avg", "median": "mdn"} - for metric in metrics: - for statistic in metric.statistics: - row = [f"{metric.description} ({statistics_abbr.get(statistic.__name__) or statistic.__name__})"] - for alg in algs: - agent_states_per_trial = resulting_agent_states[alg] - with warnings.catch_warnings(action="ignore"): - agg_data_per_trial = _aggregate_data_per_trial(agent_states_per_trial, problem, metric, statistic) + with warnings.catch_warnings(action="ignore"), utils.MetricProgressBar() as progress: + n_statistics = sum(len(metric.statistics) for metric in metrics) + table_task = progress.add_task("Generating table", total=n_statistics, status="") + for metric in metrics: + progress.update(table_task, status=f"Task: {metric.table_description}") + data_per_trial = [_data_per_trial(resulting_agent_states[a], problem, metric) for a in algs] + for statistic in metric.statistics: + row = [f"{metric.table_description} ({statistics_abbr.get(statistic.__name__) or statistic.__name__})"] + for i in range(len(algs)): + agg_data_per_trial = [statistic(trial) for trial in data_per_trial[i]] mean, margin_of_error = _calculate_mean_and_margin_of_error(agg_data_per_trial, confidence_level) - formatted_confidence_interval = _format_confidence_interval(mean, margin_of_error) - row.append(formatted_confidence_interval) - rows.append(row) + formatted_confidence_interval = _format_confidence_interval(mean, margin_of_error, metric.fmt) + row.append(formatted_confidence_interval) + rows.append(row) + progress.advance(table_task) + progress.update(table_task, status="Finalizing table") formatted_table = tb.tabulate(rows, headers, tablefmt=table_fmt) LOGGER.info("\n" + formatted_table) -def _aggregate_data_per_trial( - agents_per_trial: list[list[AgentMetricsView]], problem: BenchmarkProblem, metric: TableMetric, statistic: Statistic -) -> list[float]: - aggregated_data_per_trial: list[float] = [] +def _data_per_trial( + agents_per_trial: list[list[AgentMetricsView]], problem: BenchmarkProblem, metric: TableMetric +) -> list[Sequence[float]]: + data_per_trial: list[Sequence[float]] = [] for agents in agents_per_trial: trial_data = metric.get_data_from_trial(agents, problem) - aggregated_trial_data = statistic(trial_data) - aggregated_data_per_trial.append(aggregated_trial_data) - return aggregated_data_per_trial + data_per_trial.append(trial_data) + + return data_per_trial def _calculate_mean_and_margin_of_error(data: list[float], confidence_level: float) -> tuple[float, float]: @@ -317,11 +334,32 @@ def _calculate_mean_and_margin_of_error(data: list[float], confidence_level: flo ) if np.isfinite(mean) and np.isfinite(raw_interval).all(): return (float(mean), float(mean - raw_interval[0])) + return np.nan, np.nan -def _format_confidence_interval(mean: float, margin_of_error: float) -> str: - formatted_confidence_interval = f"{mean:.2e} \u00b1 {margin_of_error:.2e}" +def _format_confidence_interval(mean: float, margin_of_error: float, fmt: str) -> str: + if not _is_valid_float_format_spec(fmt): + LOGGER.warning(f"Invalid format string '{fmt}', defaulting to scientific notation") + fmt = ".2e" + + formatted_confidence_interval = f"{mean:{fmt}} \u00b1 {margin_of_error:{fmt}}" + if any(np.isnan([mean, margin_of_error])): formatted_confidence_interval += " (diverged?)" + return formatted_confidence_interval + + +def _is_valid_float_format_spec(fmt: str) -> bool: + """ + Validate that the given format spec can be used to format a float. + + This avoids attempting to format real values with an invalid format string. + + """ + try: + f"{0.01:{fmt}}" + except (ValueError, TypeError): + return False + return True diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 4eb780b..90fc7ca 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -479,7 +479,10 @@ def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: raise NotImplementedError("Support for multi-graphs has not been implemented yet") if not nx.is_connected(problem.network_structure): raise NotImplementedError("Support for disconnected graphs has not been implemented yet") - agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] + agents = [ + Agent(i, problem.costs[i], problem.agent_activations[i], problem.agent_state_snapshot_period) + for i in range(n_agents) + ] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) return P2PNetwork( @@ -515,7 +518,10 @@ def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: server, max_degree = max(degrees.items(), key=lambda item: item[1]) # noqa: FURB118 if max_degree != n_agents - 1 or any(deg != 1 for node, deg in degrees.items() if node != server): raise ValueError("Federated network requires a star topology (one server connected to all clients)") - agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] + agents = [ + Agent(i, problem.costs[i], problem.agent_activations[i], problem.agent_state_snapshot_period) + for i in range(n_agents) + ] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) return FedNetwork( diff --git a/docs/source/_static/plot.png b/docs/source/_static/plot.png index 0e77c24..731ffda 100644 Binary files a/docs/source/_static/plot.png and b/docs/source/_static/plot.png differ diff --git a/docs/source/api/decent_bench.metrics.metric_utils.rst b/docs/source/api/decent_bench.metrics.metric_utils.rst index 6bd20f8..80acd8b 100644 --- a/docs/source/api/decent_bench.metrics.metric_utils.rst +++ b/docs/source/api/decent_bench.metrics.metric_utils.rst @@ -4,4 +4,6 @@ decent\_bench.metrics.metric\_utils .. automodule:: decent_bench.metrics.metric_utils :members: :show-inheritance: - :undoc-members: \ No newline at end of file + :undoc-members: + :exclude-members: + MetricProgressBar, \ No newline at end of file diff --git a/docs/source/api/snippets/computational_cost.rst b/docs/source/api/snippets/computational_cost.rst new file mode 100644 index 0000000..7add96a --- /dev/null +++ b/docs/source/api/snippets/computational_cost.rst @@ -0,0 +1,8 @@ +Computational cost is calculated as: + +.. math:: + \text{Total Cost} = c_f N_f + c_g N_g + c_h N_h + c_p N_p + c_c N_c + +where :math:`c_f, c_g, c_h, c_p, c_c` are the costs per function, gradient, Hessian, proximal, and communication +call respectively, and :math:`N_f, N_g, N_h, N_p, N_c` are the mean number of function, gradient, Hessian, +proximal, and communication calls across all agents and trials. \ No newline at end of file diff --git a/docs/source/developer.rst b/docs/source/developer.rst index dd86af4..c6deaa3 100644 --- a/docs/source/developer.rst +++ b/docs/source/developer.rst @@ -23,7 +23,15 @@ Installation for Development source .tox/dev/bin/activate # activate dev env on Mac/Linux .\.tox\dev\Scripts\activate # activate dev env on Windows +Optionally install development dependencies with proper gpu support, e.g. for PyTorch and TensorFlow: +.. code-block:: + + tox -e dev-gpu + +It is not recommended to use the development environments for regular usage of decent-bench, as they +contain additional packages that are not needed for that purpose. This may cause performance degradation +due to multiple packages competing for resources (e.g. GPU resources). Tooling ------- @@ -157,4 +165,4 @@ Releases 1. Update the version in pyproject.toml using `Semantic Versioning `_. 2. Merge the change into main with commit message :code:`meta: Bump version to .. (#)`. 3. Create a new release on GitHub. -4. Publish to PyPI using :code:`hatch clean && hatch build && hatch publish`. +4. Publish to PyPI using :code:`hatch clean && hatch build && hatch publish`. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 67056eb..3172132 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,6 +9,10 @@ Welcome to decent-bench! decent-bench allows you to benchmark decentralized optimization algorithms under various communication constraints, providing realistic algorithm comparisons in a user-friendly and highly configurable setting. +Contributions are very welcome, see developer guide on how to get started. Please contact `Dr. Nicola Bastianello `_ +for discussions or start an open discussion at `GitHub `_. +Report any bugs you *may find* to `GitHub `_. + .. toctree:: :maxdepth: 1 diff --git a/docs/source/user.rst b/docs/source/user.rst index d932556..67a6dde 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -20,13 +20,12 @@ Benchmark algorithms on a regression problem without any communication constrain from decent_bench import benchmark, benchmark_problem from decent_bench.costs import LinearRegressionCost - from decent_bench.distributed_algorithms import ADMM, DGD, ED + from decent_bench.distributed_algorithms import ADMM, DGD if __name__ == "__main__": benchmark.benchmark( algorithms=[ DGD(iterations=1000, step_size=0.001), - ED(iterations=1000, step_size=0.001), ADMM(iterations=1000, rho=10, alpha=0.3), ], benchmark_problem=benchmark_problem.create_regression_problem(LinearRegressionCost), @@ -39,8 +38,10 @@ Benchmark executions will have outputs like these: * - .. image:: _static/table.png :align: center + :height: 350px - .. image:: _static/plot.png :align: center + :height: 350px Execution settings @@ -64,6 +65,10 @@ Configure settings for metrics, trials, statistical confidence level, logging, a table_metrics=[GradientCalls([min, max])], plot_metrics=[RegretPerIteration()], table_fmt="latex", + computational_cost=pm.ComputationalCost(proximal=2.0, communication=0.1), + compare_iterations_and_computational_cost=True, + plot_grid=False, + plot_path="plots.png", n_trials=10, confidence_level=0.9, log_level=DEBUG, @@ -89,6 +94,7 @@ Configure communication constraints and other settings for out-of-the-box regres problem = benchmark_problem.create_regression_problem( LinearRegressionCost, n_agents=100, + agent_state_snapshot_period=10, # Record metrics every 10 iterations n_neighbors_per_agent=3, asynchrony=True, compression=True, @@ -380,7 +386,7 @@ Create your own metrics to tabulate and/or plot. return float(la.norm(iop.to_numpy(problem.optimal_x) - iop.to_numpy(agent.x_per_iteration[i]))) class XError(TableMetric): - description: str = "x error" + table_description: str = "x error" def get_data_from_trial( self, agents: list[AgentMetricsView], problem: BenchmarkProblem @@ -388,8 +394,7 @@ Create your own metrics to tabulate and/or plot. return [x_error_at_iter(a, problem) for a in agents] class MaxXErrorPerIteration(PlotMetric): - x_label: str = "iteration" - y_label: str = "max x error" + plot_description: str = "max x error" def get_data_from_trial( self, agents: list[AgentMetricsView], problem: BenchmarkProblem diff --git a/test/test_agents.py b/test/test_agents.py new file mode 100644 index 0000000..58f0e73 --- /dev/null +++ b/test/test_agents.py @@ -0,0 +1,264 @@ +import numpy as np +import pytest + +import decent_bench.utils.interoperability as iop +from decent_bench.agents import Agent +from decent_bench.costs import LinearRegressionCost +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks + +try: + import torch + + TORCH_AVAILABLE = True + TORCH_CUDA_AVAILABLE = torch.cuda.is_available() +except ModuleNotFoundError: + TORCH_AVAILABLE = False + TORCH_CUDA_AVAILABLE = False + +try: + import tensorflow as tf + + TF_AVAILABLE = True + TF_GPU_AVAILABLE = len(tf.config.list_physical_devices("GPU")) > 0 +except (ImportError, ModuleNotFoundError): + TF_AVAILABLE = False + TF_GPU_AVAILABLE = False + +try: + import jax + + JAX_AVAILABLE = True + JAX_GPU_AVAILABLE = len(jax.devices("gpu")) > 0 +except (ImportError, ModuleNotFoundError): + JAX_AVAILABLE = False + JAX_GPU_AVAILABLE = False +except RuntimeError: + # JAX raises RuntimeError if no GPU is available when querying devices + JAX_GPU_AVAILABLE = False + + +@pytest.mark.parametrize( + "framework,device", + [ + pytest.param( + SupportedFrameworks.NUMPY, + SupportedDevices.CPU, + ), + pytest.param( + SupportedFrameworks.TORCH, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + SupportedFrameworks.TORCH, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + SupportedFrameworks.TENSORFLOW, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + SupportedFrameworks.TENSORFLOW, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + SupportedFrameworks.JAX, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + SupportedFrameworks.JAX, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_in_place_operations_history(framework: SupportedFrameworks, device: SupportedDevices): + """Test that in-place operations on agent.x properly update the history.""" + agent = Agent(0, LinearRegressionCost(np.array([[1.0, 1.0, 1.0]]), np.array([1.0])), None, state_snapshot_period=1) # type: ignore # noqa: PGH003 + + initial = iop.zeros((3,), framework=framework, device=device) + agent.initialize(x=initial) + + def assert_state(expected_x, expected_history): + """Helper to verify agent state and history.""" + np.testing.assert_array_almost_equal( + iop.to_numpy(agent.x), + expected_x, + decimal=5, + err_msg=f"Expected x: {expected_x}, but got: {iop.to_numpy(agent.x)}", + ) + assert len(agent._x_history) == len(expected_history), ( + f"Expected history length: {len(expected_history)}, but got: {len(agent._x_history)}" + ) + for i, expected in enumerate(expected_history): + np.testing.assert_array_almost_equal( + iop.to_numpy(agent._x_history[i]), + expected, + decimal=5, + err_msg=f"At history index {i}, expected: {expected}, but got: {iop.to_numpy(agent._x_history)}", + ) + + # Initial state + assert_state( + np.array([0.0, 0.0, 0.0]), + [ + np.array([0.0, 0.0, 0.0]), + ], + ) + + # Test += operator + agent.x += 1.0 + assert_state( + np.array([1.0, 1.0, 1.0]), + [ + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + ], + ) + + # Test *= operator + agent.x *= 2.0 + assert_state( + np.array([2.0, 2.0, 2.0]), + [ + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + np.array([2.0, 2.0, 2.0]), + ], + ) + + # Test **= operator + agent.x **= 2.0 + assert_state( + np.array([4.0, 4.0, 4.0]), + [ + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + np.array([2.0, 2.0, 2.0]), + np.array([4.0, 4.0, 4.0]), + ], + ) + + # Test /= operator + agent.x /= 2.0 + assert_state( + np.array([2.0, 2.0, 2.0]), + [ + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + np.array([2.0, 2.0, 2.0]), + np.array([4.0, 4.0, 4.0]), + np.array([2.0, 2.0, 2.0]), + ], + ) + + # Test -= operator + agent.x -= 1.0 + assert_state( + np.array([1.0, 1.0, 1.0]), + [ + np.array([0.0, 0.0, 0.0]), + np.array([1.0, 1.0, 1.0]), + np.array([2.0, 2.0, 2.0]), + np.array([4.0, 4.0, 4.0]), + np.array([2.0, 2.0, 2.0]), + np.array([1.0, 1.0, 1.0]), + ], + ) + + +@pytest.mark.parametrize( + "framework,device", + [ + pytest.param( + SupportedFrameworks.NUMPY, + SupportedDevices.CPU, + ), + pytest.param( + SupportedFrameworks.TORCH, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + SupportedFrameworks.TORCH, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + SupportedFrameworks.TENSORFLOW, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + SupportedFrameworks.TENSORFLOW, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + SupportedFrameworks.JAX, + SupportedDevices.CPU, + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + SupportedFrameworks.JAX, + SupportedDevices.GPU, + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize("state_snapshot_period", [1, 5, 10]) +def test_agent_state_snapshot_period(framework: SupportedFrameworks, device: SupportedDevices, state_snapshot_period: int): + """Test that agent history is recorded according to the specified history period.""" + agent = Agent( + 0, + LinearRegressionCost(np.array([[1.0, 1.0, 1.0]]), np.array([1.0])), + None, + state_snapshot_period=state_snapshot_period, + ) + + initial = iop.zeros((3,), framework=framework, device=device) + agent.initialize(x=initial) + + def assert_state(expected_x, expected_history): + """Helper to verify agent state and history.""" + np.testing.assert_array_almost_equal( + iop.to_numpy(agent.x), + expected_x, + decimal=5, + err_msg=f"Expected x: {expected_x}, but got: {iop.to_numpy(agent.x)}", + ) + assert len(agent._x_history) == len(expected_history), ( + f"Expected history length: {len(expected_history)}, but got: {len(agent._x_history)}" + ) + steps = sorted(agent._x_history.keys()) + assert steps == list(range(0, state_snapshot_period * (len(expected_history)), state_snapshot_period)), ( + f"Expected history steps: {list(range(0, state_snapshot_period * (len(expected_history)), state_snapshot_period))}, " + f"but got: {steps}" + ) + for i, expected in zip(steps, expected_history, strict=True): + np.testing.assert_array_almost_equal( + iop.to_numpy(agent._x_history[i]), + expected, + decimal=5, + err_msg=f"At history index {i}, expected: {expected}, but got: {iop.to_numpy(agent._x_history)}", + ) + + expected_history_length = 5 # Excluding the initial state, so +1 later + n_updates = expected_history_length * state_snapshot_period + for _ in range(n_updates): + agent.x += 1.0 + + assert_state( + np.array([n_updates, n_updates, n_updates]), + [ + np.array([0.0, 0.0, 0.0]), + ] + + [ + np.array([i * state_snapshot_period, i * state_snapshot_period, i * state_snapshot_period]) + for i in range(1, expected_history_length + 1) + ], + )