From 1380951d6d2db5adafcf1a286f21862b09f95058 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Tue, 9 Dec 2025 18:24:27 +0100 Subject: [PATCH 01/16] feat(networks): add network base and federated support Introduce a shared Network base with common message settings and accessors, refactor P2PNetwork to subclass it, and add FedNetwork with server/client helpers plus star-topology validation and a federated factory. --- decent_bench/networks.py | 318 ++++++++++++++++++++++++++++++++------- 1 file changed, 265 insertions(+), 53 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 72594d5..d4f3fb5 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,4 +1,7 @@ +from abc import ABC, abstractmethod +from collections.abc import Mapping from functools import cached_property +from operator import itemgetter from typing import TYPE_CHECKING import networkx as nx @@ -17,17 +20,8 @@ AgentGraph = Graph -class P2PNetwork: - """ - Peer-to-Peer Network of agents that communicate by sending and receiving messages. - - Args: - graph: topology defining how the agents are connected - message_noise: message noise setting - message_compression: message compression setting - message_drop: message drops setting - - """ +class Network(ABC): + """Base network that defines the communication constraints shared by all network types.""" def __init__( self, @@ -35,39 +29,40 @@ def __init__( message_noise: NoiseScheme, message_compression: CompressionScheme, message_drop: DropScheme, - ): + ) -> None: self._graph = graph self._message_noise = message_noise self._message_compression = message_compression self._message_drop = message_drop - @cached_property - def weights(self) -> NDArray[float64]: - """ - Symmetric, doubly stochastic matrix for consensus weights. Initialized using the Metropolis-Hastings method. 
+ @property + def graph(self) -> AgentGraph: + """Underlying agent graph.""" + return self._graph - Use ``weights[i, j]`` or ``weights[i.id, j.id]`` to get the weight between agent i and j. - """ - agents = self.agents() - n = len(agents) - W = np.zeros((n, n)) # noqa: N806 - for i in agents: - neighbors = self.neighbors(i) - d_i = len(neighbors) - for j in neighbors: - d_j = len(self.neighbors(j)) - W[i, j] = 1 / (1 + max(d_i, d_j)) - for i in agents: - W[i, i] = 1 - sum(W[i]) - return W + @property + def message_noise(self) -> NoiseScheme: + """Noise scheme applied to messages.""" + return self._message_noise + + @property + def message_compression(self) -> CompressionScheme: + """Compression scheme applied to messages.""" + return self._message_compression + + @property + def message_drop(self) -> DropScheme: + """Drop scheme applied to messages.""" + return self._message_drop + + @abstractmethod + def kind(self) -> str: + """Label for the network subtype (e.g., 'p2p', 'fed').""" + raise NotImplementedError def agents(self) -> list[Agent]: """Get all agents in the network.""" - return list(self._graph) - - def neighbors(self, agent: Agent) -> list[Agent]: - """Get all neighbors of an agent.""" - return list(self._graph[agent]) + return list(self.graph) def active_agents(self, iteration: int) -> list[Agent]: """ @@ -91,12 +86,71 @@ def send(self, sender: Agent, receiver: Agent, msg: NDArray[float64]) -> None: same receiver. After being received or replaced, the message is destroyed. 
""" sender._n_sent_messages += 1 # noqa: SLF001 - if self._message_drop.should_drop(): + if self.message_drop.should_drop(): sender._n_sent_messages_dropped += 1 # noqa: SLF001 return - msg = self._message_compression.compress(msg) - msg = self._message_noise.make_noise(msg) - self._graph.edges[sender, receiver][str(receiver.id)] = msg + msg = self.message_compression.compress(msg) + msg = self.message_noise.make_noise(msg) + self.graph.edges[sender, receiver][str(receiver.id)] = msg + + def receive(self, receiver: Agent, sender: Agent) -> None: + """ + Receive message from a neighbor. + + Received messages are stored in + :attr:`Agent.messages `. + """ + msg = self.graph.edges[sender, receiver].get(str(receiver.id)) + if msg is not None: + receiver._n_received_messages += 1 # noqa: SLF001 + receiver._received_messages[sender] = msg # noqa: SLF001 + self.graph.edges[sender, receiver][str(receiver.id)] = None + + +class P2PNetwork(Network): + """Peer-to-peer network of agents that communicate by sending and receiving messages.""" + + def __init__( + self, + graph: AgentGraph, + message_noise: NoiseScheme, + message_compression: CompressionScheme, + message_drop: DropScheme, + ) -> None: + super().__init__( + graph=graph, + message_noise=message_noise, + message_compression=message_compression, + message_drop=message_drop, + ) + + def kind(self) -> str: + """Label for the network subtype.""" + return "p2p" + + @cached_property + def weights(self) -> NDArray[float64]: + """ + Symmetric, doubly stochastic matrix for consensus weights. Initialized using the Metropolis-Hastings method. + + Use ``weights[i, j]`` or ``weights[i.id, j.id]`` to get the weight between agent i and j. 
+ """ + agents = self.agents() + n = len(agents) + W = np.zeros((n, n)) # noqa: N806 + for i in agents: + neighbors = self.neighbors(i) + d_i = len(neighbors) + for j in neighbors: + d_j = len(self.neighbors(j)) + W[i, j] = 1 / (1 + max(d_i, d_j)) + for i in agents: + W[i, i] = 1 - sum(W[i]) + return W + + def neighbors(self, agent: Agent) -> list[Agent]: + """Get all neighbors of an agent.""" + return list(self.graph[agent]) def broadcast(self, sender: Agent, msg: NDArray[float64]) -> None: """ @@ -110,31 +164,153 @@ def broadcast(self, sender: Agent, msg: NDArray[float64]) -> None: The message will stay in-flight until it is received or replaced by a newer message from the same sender to the same receiver. After being received or replaced, the message is destroyed. """ - for neighbor in self._graph.neighbors(sender): + for neighbor in self.neighbors(sender): self.send(sender=sender, receiver=neighbor, msg=msg) - def receive(self, receiver: Agent, sender: Agent) -> None: + def receive_all(self, receiver: Agent) -> None: """ - Receive message from a neighbor. + Receive messages from all neighbors. Received messages are stored in :attr:`Agent.messages `. 
""" - msg = self._graph.edges[sender, receiver].get(str(receiver.id)) - if msg is not None: - receiver._n_received_messages += 1 # noqa: SLF001 - receiver._received_messages[sender] = msg # noqa: SLF001 - self._graph.edges[sender, receiver][str(receiver.id)] = None + for neighbor in self.neighbors(receiver): + self.receive(receiver, neighbor) - def receive_all(self, receiver: Agent) -> None: + +class FedNetwork(Network): + """Federated learning network with one server node connected to all client nodes (star topology).""" + + def __init__( + self, + graph: AgentGraph, + message_noise: NoiseScheme, + message_compression: CompressionScheme, + message_drop: DropScheme, + ) -> None: + super().__init__( + graph=graph, + message_noise=message_noise, + message_compression=message_compression, + message_drop=message_drop, + ) + self._server = self._identify_server() + + def _identify_server(self) -> Agent: + degrees = dict(self.graph.degree()) + if not degrees: + raise ValueError("FedNetwork requires at least one agent") + server, max_degree = max(degrees.items(), key=itemgetter(1)) + n = len(degrees) + if max_degree != n - 1 or any(deg != 1 for node, deg in degrees.items() if node != server): + raise ValueError("FedNetwork expects a star topology with one server connected to all clients") + return server + + def kind(self) -> str: + """Label for the network subtype.""" + return "fed" + + @property + def server(self) -> Agent: + """Agent acting as the central server.""" + return self._server + + @property + def clients(self) -> list[Agent]: + """Agents acting as clients.""" + return [agent for agent in self.graph if agent is not self.server] + + def active_clients(self, iteration: int) -> list[Agent]: """ - Receive messages from all neighbors. + Get all active clients (excludes the server). - Received messages are stored in - :attr:`Agent.messages `. + Uses :meth:`Network.active_agents` to honor activation schemes and then filters out the server. 
""" - for neighbor in self._graph.neighbors(receiver): - self.receive(receiver, neighbor) + return [agent for agent in self.active_agents(iteration) if agent is not self.server] + + def send_to_client(self, client: Agent, msg: NDArray[float64]) -> None: + """ + Send a message from the server to a specific client. + + Raises: + ValueError: if the receiver is not a client. + + """ + if client not in self.clients: + raise ValueError("Receiver must be a client") + self.send(sender=self.server, receiver=client, msg=msg) + + def send_to_all_clients(self, msg: NDArray[float64]) -> None: + """Send the same message from the server to every client (synchronous FL push).""" + for client in self.clients: + self.send_to_client(client, msg) + + def send_from_client(self, client: Agent, msg: NDArray[float64]) -> None: + """ + Send a message from a client to the server. + + Raises: + ValueError: if the sender is not a client. + + """ + if client not in self.clients: + raise ValueError("Sender must be a client") + self.send(sender=client, receiver=self.server, msg=msg) + + def send_from_all_clients(self, msgs: Mapping[Agent, NDArray[float64]]) -> None: + """ + Send messages from each client to the server (synchronous FL push). + + Args: + msgs: mapping from client Agent to the message that client should send. Must include all clients. + + Raises: + ValueError: if any sender is not a client or if any client is missing. + + """ + clients = set(self.clients) + invalid = [client for client in msgs if client not in clients] + if invalid: + raise ValueError("All senders must be clients") + missing = clients - set(msgs) + if missing: + raise ValueError("Messages must be provided for all clients") + for client, msg in msgs.items(): + self.send_from_client(client, msg) + + def receive_at_client(self, client: Agent) -> None: + """ + Receive a message at a client from the server. + + Raises: + ValueError: if the receiver is not a client. 
+ + """ + if client not in self.clients: + raise ValueError("Receiver must be a client") + self.receive(receiver=client, sender=self.server) + + def receive_at_all_clients(self) -> None: + """Receive messages at every client from the server (synchronous FL pull).""" + for client in self.clients: + self.receive_at_client(client) + + def receive_from_client(self, client: Agent) -> None: + """ + Receive a message at the server from a specific client. + + Raises: + ValueError: if the sender is not a client. + + """ + if client not in self.clients: + raise ValueError("Sender must be a client") + self.receive(receiver=self.server, sender=client) + + def receive_from_all_clients(self) -> None: + """Receive messages at the server from every client (synchronous FL pull).""" + for client in self.clients: + self.receive_from_client(client) def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: @@ -165,3 +341,39 @@ def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: message_compression=problem.message_compression, message_drop=problem.message_drop, ) + + +def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: + """ + Create a federated learning network with a single server and multiple clients (star topology). 
+ + Raises: + ValueError: if there are fewer activation schemes or cost functions than agents + ValueError: if the provided graph is not a star (one server connected to all clients) + + """ + n_agents = len(problem.network_structure) + if len(problem.agent_activations) < n_agents: + raise ValueError("Insufficient number of agent activation schemes, please provide one per agent") + if len(problem.costs) < n_agents: + raise ValueError("Insufficient number of cost functions, please provide one per agent") + if problem.network_structure.is_directed(): + raise NotImplementedError("Support for directed graphs has not been implemented yet") + if problem.network_structure.is_multigraph(): + raise NotImplementedError("Support for multi-graphs has not been implemented yet") + if not nx.is_connected(problem.network_structure): + raise NotImplementedError("Support for disconnected graphs has not been implemented yet") + degrees = dict(problem.network_structure.degree()) + if n_agents: + server, max_degree = max(degrees.items(), key=itemgetter(1)) + if max_degree != n_agents - 1 or any(deg != 1 for node, deg in degrees.items() if node != server): + raise ValueError("Federated network requires a star topology (one server connected to all clients)") + agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] + agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} + graph = nx.relabel_nodes(problem.network_structure, agent_node_map) + return FedNetwork( + graph=graph, + message_noise=problem.message_noise, + message_compression=problem.message_compression, + message_drop=problem.message_drop, + ) From 9a669db7dcc914115009f42a2139e685281f59bf Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Wed, 10 Dec 2025 19:39:19 +0100 Subject: [PATCH 02/16] wip(fl-setup): save network changes after merge (Array) --- decent_bench/agents.py | 43 +- decent_bench/benchmark_problem.py | 19 +- 
decent_bench/centralized_algorithms.py | 34 +- decent_bench/costs.py | 205 ++- decent_bench/datasets.py | 31 +- decent_bench/distributed_algorithms.py | 137 +- decent_bench/metrics/metric_utils.py | 12 +- decent_bench/metrics/table_metrics.py | 3 +- decent_bench/networks.py | 67 +- decent_bench/schemes.py | 29 +- decent_bench/utils/array.py | 347 ++++++ decent_bench/utils/interoperability.py | 690 ----------- .../utils/interoperability/__init__.py | 102 ++ .../utils/interoperability/_decorators.py | 110 ++ decent_bench/utils/interoperability/_ext.py | 15 + .../utils/interoperability/_functions.py | 936 ++++++++++++++ .../utils/interoperability/_helpers.py | 91 ++ .../utils/interoperability/_imports_types.py | 47 + .../utils/interoperability/_operators.py | 483 ++++++++ decent_bench/utils/types.py | 46 + docs/source/api/decent_bench.costs.rst | 4 +- docs/source/api/decent_bench.utils.array.rst | 7 + .../decent_bench.utils.interoperability.rst | 4 +- docs/source/api/decent_bench.utils.rst | 4 +- docs/source/api/decent_bench.utils.types.rst | 7 + .../source/api/snippets/proximal_operator.rst | 4 +- docs/source/conf.py | 20 +- docs/source/developer.rst | 4 +- docs/source/user.rst | 136 +- pyproject.toml | 21 +- test/utils/test_array.py | 468 +++++++ test/utils/test_interoperability.py | 1098 +++++++++++------ 32 files changed, 3905 insertions(+), 1319 deletions(-) create mode 100644 decent_bench/utils/array.py delete mode 100644 decent_bench/utils/interoperability.py create mode 100644 decent_bench/utils/interoperability/__init__.py create mode 100644 decent_bench/utils/interoperability/_decorators.py create mode 100644 decent_bench/utils/interoperability/_ext.py create mode 100644 decent_bench/utils/interoperability/_functions.py create mode 100644 decent_bench/utils/interoperability/_helpers.py create mode 100644 decent_bench/utils/interoperability/_imports_types.py create mode 100644 decent_bench/utils/interoperability/_operators.py create mode 100644 
decent_bench/utils/types.py create mode 100644 docs/source/api/decent_bench.utils.array.rst create mode 100644 docs/source/api/decent_bench.utils.types.rst create mode 100644 test/utils/test_array.py diff --git a/decent_bench/agents.py b/decent_bench/agents.py index 332a489..8dcfd9c 100644 --- a/decent_bench/agents.py +++ b/decent_bench/agents.py @@ -4,11 +4,9 @@ from dataclasses import dataclass from types import MappingProxyType -from numpy import float64 -from numpy.typing import NDArray - from decent_bench.costs import Cost from decent_bench.schemes import AgentActivationScheme +from decent_bench.utils.array import Array class Agent: @@ -18,9 +16,9 @@ def __init__(self, agent_id: int, cost: Cost, activation: AgentActivationScheme) self._id = agent_id self._cost = cost self._activation = activation - self._x_history: list[NDArray[float64]] = [] - self._auxiliary_variables: dict[str, NDArray[float64]] = {} - self._received_messages: dict[Agent, NDArray[float64]] = {} + self._x_history: list[Array] = [] + self._auxiliary_variables: dict[str, Array] = {} + self._received_messages: dict[Agent, Array] = {} self._n_sent_messages = 0 self._n_received_messages = 0 self._n_sent_messages_dropped = 0 @@ -52,10 +50,17 @@ def cost(self) -> Cost: loss = cost @property - def x(self) -> NDArray[float64]: + def x(self) -> Array: """ Local optimization variable x. + Warning: + Do not use in-place operations (``+=``, ``-=``, ``*=``, etc.) on this property. + In-place operations will corrupt the optimization history by modifying all + historical values. Always use ``agent.x = agent.x + value`` instead of + ``agent.x += value``. Does not affect the outcome of the optimization, but + will affect logging and metrics that depend on the optimization history. 
+ Raises: RuntimeError: if x is retrieved before being set or initialized @@ -65,25 +70,25 @@ def x(self) -> NDArray[float64]: return self._x_history[-1] @x.setter - def x(self, x: NDArray[float64]) -> None: + def x(self, x: Array) -> None: self._x_history.append(x) @property - def messages(self) -> Mapping[Agent, NDArray[float64]]: + def messages(self) -> Mapping[Agent, Array]: """Messages received by neighbors.""" return MappingProxyType(self._received_messages) @property - def aux_vars(self) -> dict[str, NDArray[float64]]: + def aux_vars(self) -> dict[str, Array]: """Auxiliary optimization variables used by algorithms that require more variables than x.""" return self._auxiliary_variables def initialize( self, *, - x: NDArray[float64] | None = None, - aux_vars: dict[str, NDArray[float64]] | None = None, - received_msgs: dict[Agent, NDArray[float64]] | None = None, + x: Array | None = None, + aux_vars: dict[str, Array] | None = None, + received_msgs: dict[Agent, Array] | None = None, ) -> None: """ Initialize local variables and messages before running an algorithm. 
@@ -101,21 +106,21 @@ def initialize( if received_msgs: self._received_messages = received_msgs - def _call_counting_function(self, x: NDArray[float64]) -> float: + def _call_counting_function(self, x: Array) -> float: self._n_function_calls += 1 return self._cost.__class__.function(self.cost, x) - def _call_counting_gradient(self, x: NDArray[float64]) -> NDArray[float64]: + def _call_counting_gradient(self, x: Array) -> Array: self._n_gradient_calls += 1 return self._cost.__class__.gradient(self.cost, x) - def _call_counting_hessian(self, x: NDArray[float64]) -> NDArray[float64]: + def _call_counting_hessian(self, x: Array) -> Array: self._n_hessian_calls += 1 return self._cost.__class__.hessian(self.cost, x) - def _call_counting_proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + def _call_counting_proximal(self, x: Array, rho: float) -> Array: self._n_proximal_calls += 1 - return self._cost.__class__.proximal(self.cost, y, rho) + return self._cost.__class__.proximal(self.cost, x, rho) def __index__(self) -> int: """Enable using agent as index, for example ``W[a1, a2]`` instead of ``W[a1.id, a2.id]``.""" @@ -127,7 +132,7 @@ class AgentMetricsView: """Immutable view of agent that exposes useful properties for calculating metrics.""" cost: Cost - x_history: list[NDArray[float64]] + x_history: list[Array] n_function_calls: int n_gradient_calls: int n_hessian_calls: int diff --git a/decent_bench/benchmark_problem.py b/decent_bench/benchmark_problem.py index bc2fb26..ad7377a 100644 --- a/decent_bench/benchmark_problem.py +++ b/decent_bench/benchmark_problem.py @@ -5,9 +5,6 @@ from typing import TYPE_CHECKING, Any import networkx as nx -from networkx import Graph -from numpy import float64 -from numpy.typing import NDArray import decent_bench.centralized_algorithms as ca from decent_bench.costs import Cost, LinearRegressionCost, LogisticRegressionCost @@ -26,11 +23,13 @@ UniformActivationRate, UniformDropRate, ) +from decent_bench.utils.array import 
Array +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks if TYPE_CHECKING: - AnyGraph = Graph[Any] + AnyGraph = nx.Graph[Any] else: - AnyGraph = Graph + AnyGraph = nx.Graph @dataclass(eq=False) @@ -50,7 +49,7 @@ class BenchmarkProblem: """ network_structure: AnyGraph - x_optimal: NDArray[float64] + x_optimal: Array costs: Sequence[Cost] agent_activations: Sequence[AgentActivationScheme] message_compression: CompressionScheme @@ -83,7 +82,13 @@ def create_regression_problem( """ network_structure = nx.random_regular_graph(n_neighbors_per_agent, n_agents, seed=0) dataset = SyntheticClassificationData( - n_classes=2, n_partitions=n_agents, n_samples_per_partition=10, n_features=3, seed=0 + n_classes=2, + n_partitions=n_agents, + n_samples_per_partition=10, + n_features=3, + framework=SupportedFrameworks.NUMPY, + device=SupportedDevices.CPU, + seed=0, ) costs = [cost_cls(*p) for p in dataset.training_partitions()] sum_cost = reduce(add, costs) diff --git a/decent_bench/centralized_algorithms.py b/decent_bench/centralized_algorithms.py index 29fdd6d..74c2dd3 100644 --- a/decent_bench/centralized_algorithms.py +++ b/decent_bench/centralized_algorithms.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING import numpy as np -from numpy import float64 -from numpy import linalg as la -from numpy.typing import NDArray + +import decent_bench.utils.interoperability as iop +from decent_bench.utils.array import Array if TYPE_CHECKING: from decent_bench.costs import Cost @@ -11,19 +11,19 @@ def gradient_descent( cost: "Cost", - x0: NDArray[float64] | None, + x0: Array | None, *, step_size: float, max_iter: int, stop_tol: float | None, max_tol: float | None, -) -> NDArray[float64]: +) -> Array: """ Find the x that minimizes the cost function using gradient descent. 
Args: cost: cost function to minimize - x0: initial guess, defaults to ``np.zeros()`` if ``None`` is provided + x0: initial guess, defaults to ``iop.zeros()`` if ``None`` is provided step_size: scaling factor for each update max_iter: maximum number of iterations to run stop_tol: early stopping criteria - stop if ``norm(x_new - x) <= stop_tol`` @@ -37,10 +37,10 @@ def gradient_descent( """ delta = np.inf - x = x0 if x0 is not None else np.zeros(cost.shape) + x = x0 if x0 is not None else iop.zeros(shape=cost.shape, framework=cost.framework, device=cost.device) for _ in range(max_iter): x_new = x - step_size * cost.gradient(x) - delta = float(la.norm(x_new - x)) + delta = float(iop.norm(x_new - x)) x = x_new if stop_tol is not None and delta <= stop_tol: break @@ -55,18 +55,18 @@ def gradient_descent( def accelerated_gradient_descent( cost: "Cost", - x0: NDArray[float64] | None, + x0: Array | None, *, max_iter: int, stop_tol: float | None, max_tol: float | None, -) -> NDArray[float64]: +) -> Array: r""" Find the x that minimizes the cost function using accelerated gradient descent. Args: cost: cost function to minimize - x0: initial guess, defaults to ``np.zeros()`` if ``None`` is provided + x0: initial guess, defaults to ``iop.zeros()`` if ``None`` is provided max_iter: maximum number of iterations to run stop_tol: early stopping criteria - stop if ``norm(x_new - x) <= stop_tol`` max_tol: maximum tolerated ``norm(x_new - x)`` at the end @@ -79,7 +79,7 @@ def accelerated_gradient_descent( x that minimizes the cost function. 
""" - if x0 is not None and x0.shape != cost.shape: + if x0 is not None and iop.shape(x0) != cost.shape: raise ValueError("x0 and cost function domain must have same shape") if cost.m_smooth == 0: raise ValueError("Function must not be affine") @@ -93,14 +93,14 @@ def accelerated_gradient_descent( raise NotImplementedError("Support for non-global differentiability is not implemented yet") if np.isnan(cost.m_cvx): raise NotImplementedError("Support for non-convexity is not implemented yet") - x0 = x0 if x0 is not None else np.zeros(cost.shape) + x0 = x0 if x0 is not None else iop.zeros(shape=cost.shape, framework=cost.framework, device=cost.device) x = x0 y = x0 c = (np.sqrt(cost.m_smooth) - np.sqrt(cost.m_cvx)) / (np.sqrt(cost.m_smooth) + np.sqrt(cost.m_cvx)) delta = np.inf for k in range(1, max_iter + 1): x_new = y - cost.gradient(y) / cost.m_smooth - delta = float(la.norm(x_new - x)) + delta = float(iop.norm(x_new - x)) beta = c if cost.m_cvx > 0 else (k - 1) / (k + 2) y_new = x_new + beta * (x_new - x) x, y = x_new, y_new @@ -115,7 +115,7 @@ def accelerated_gradient_descent( return x -def proximal_solver(cost: "Cost", y: NDArray[float64], rho: float) -> NDArray[float64]: +def proximal_solver(cost: "Cost", y: Array, rho: float) -> Array: """ Find the proximal at y using accelerated gradient descent. 
@@ -127,11 +127,11 @@ def proximal_solver(cost: "Cost", y: NDArray[float64], rho: float) -> NDArray[fl ValueError: if *cost*'s domain and *y* don't have the same shape, or if *rho* is not greater than 0 """ - if cost.shape != y.shape: + if cost.shape != iop.shape(y): raise ValueError("Cost function domain and y need to have the same shape") if rho <= 0: raise ValueError("Penalty term `rho` must be greater than 0") from decent_bench.costs import QuadraticCost # noqa: PLC0415 - proximal_cost = QuadraticCost(A=np.eye(len(y)) / rho, b=-y / rho, c=y.dot(y) / (2 * rho)) + cost + proximal_cost = QuadraticCost(A=iop.eye_like(y) / rho, b=-y / rho, c=float(iop.dot(y, y)) / (2 * rho)) + cost return accelerated_gradient_descent(proximal_cost, y, max_iter=100, stop_tol=1e-10, max_tol=None) diff --git a/decent_bench/costs.py b/decent_bench/costs.py index b4498a3..9e7a36a 100644 --- a/decent_bench/costs.py +++ b/decent_bench/costs.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy from abc import ABC, abstractmethod from functools import cached_property @@ -11,6 +10,9 @@ from scipy import special import decent_bench.centralized_algorithms as ca +import decent_bench.utils.interoperability as iop +from decent_bench.utils.array import Array +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks class Cost(ABC): @@ -26,6 +28,28 @@ def domain_shape(self) -> tuple[int, ...]: """Alias for :attr:`shape`.""" return self.shape + @property + @abstractmethod + def framework(self) -> SupportedFrameworks: + """ + The framework used by this cost function. + + Make sure that all :class:`decent_bench.utils.array.Array` objects returned by this cost function's methods + use this framework. + + """ + + @property + @abstractmethod + def device(self) -> SupportedDevices: + """ + The device used by this cost function. + + Make sure that all :class:`decent_bench.utils.array.Array` objects returned by this cost function's methods + use this device. 
+ + """ + @property @abstractmethod def m_smooth(self) -> float: @@ -64,33 +88,33 @@ def m_cvx(self) -> float: """ @abstractmethod - def function(self, x: NDArray[float64]) -> float: + def function(self, x: Array) -> float: """Evaluate function at x.""" - def evaluate(self, x: NDArray[float64]) -> float: + def evaluate(self, x: Array) -> float: """Alias for :meth:`function`.""" return self.function(x) - def loss(self, x: NDArray[float64]) -> float: + def loss(self, x: Array) -> float: """Alias for :meth:`function`.""" return self.function(x) - def f(self, x: NDArray[float64]) -> float: + def f(self, x: Array) -> float: """Alias for :meth:`function`.""" return self.function(x) @abstractmethod - def gradient(self, x: NDArray[float64]) -> NDArray[float64]: + def gradient(self, x: Array) -> Array: """Gradient at x.""" @abstractmethod - def hessian(self, x: NDArray[float64]) -> NDArray[float64]: + def hessian(self, x: Array) -> Array: """Hessian at x.""" @abstractmethod - def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + def proximal(self, x: Array, rho: float) -> Array: r""" - Proximal at y. + Proximal at x. The proximal operator is defined as: @@ -121,26 +145,36 @@ class QuadraticCost(Cost): .. 
math:: f(\mathbf{x}) = \frac{1}{2} \mathbf{x}^T \mathbf{Ax} + \mathbf{b}^T \mathbf{x} + c """ - def __init__(self, A: NDArray[float64], b: NDArray[float64], c: float): # noqa: N803 - if A.ndim != 2: + def __init__(self, A: Array, b: Array, c: float): # noqa: N803 + self.A: NDArray[float64] = iop.to_numpy(A) + self.b: NDArray[float64] = iop.to_numpy(b) + + if self.A.ndim != 2: raise ValueError("Matrix A must be 2D") - if A.shape[0] != A.shape[1]: + if self.A.shape[0] != self.A.shape[1]: raise ValueError("Matrix A must be square") - if b.ndim != 1: + if self.b.ndim != 1: raise ValueError("Vector b must be 1D") - if A.shape[0] != b.shape[0]: - raise ValueError(f"Dimension mismatch: A has shape {A.shape} but b has length {b.shape[0]}") - self.A = A - self.A_sym = 0.5 * (A + A.T) - self.b = b + if self.A.shape[0] != self.b.shape[0]: + raise ValueError(f"Dimension mismatch: A has shape {self.A.shape} but b has length {self.b.shape[0]}") + + self.A_sym = 0.5 * (self.A + self.A.T) self.c = c @property def shape(self) -> tuple[int, ...]: # noqa: D102 return self.b.shape + @property + def framework(self) -> SupportedFrameworks: # noqa: D102 + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: # noqa: D102 + return SupportedDevices.CPU + @cached_property - def m_smooth(self) -> float: + def m_smooth(self) -> float: # pyright: ignore[reportIncompatibleMethodOverride] r""" The cost function's smoothness constant. @@ -156,7 +190,7 @@ def m_smooth(self) -> float: return float(np.max(np.abs(eigs))) @cached_property - def m_cvx(self) -> float: + def m_cvx(self) -> float: # pyright: ignore[reportIncompatibleMethodOverride] r""" The cost function's convexity constant. @@ -181,6 +215,7 @@ def m_cvx(self) -> float: return 0 return np.nan + @iop.autodecorate_cost_method(Cost.function) def function(self, x: NDArray[float64]) -> float: r""" Evaluate function at x. 
@@ -189,6 +224,7 @@ def function(self, x: NDArray[float64]) -> float: """ return float(0.5 * x.dot(self.A.dot(x)) + self.b.dot(x) + self.c) + @iop.autodecorate_cost_method(Cost.gradient) def gradient(self, x: NDArray[float64]) -> NDArray[float64]: r""" Gradient at x. @@ -197,20 +233,23 @@ def gradient(self, x: NDArray[float64]) -> NDArray[float64]: """ return self.A_sym @ x + self.b + @iop.autodecorate_cost_method(Cost.hessian) def hessian(self, x: NDArray[float64]) -> NDArray[float64]: # noqa: ARG002 r""" Hessian at x. .. math:: \frac{1}{2} (\mathbf{A}+\mathbf{A}^T) """ - return self.A_sym + ret: NDArray[float64] = self.A_sym.copy() + return ret - def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + @iop.autodecorate_cost_method(Cost.proximal) + def proximal(self, x: NDArray[float64], rho: float) -> NDArray[float64]: r""" - Proximal at y. + Proximal at x. .. math:: - (\frac{\rho}{2} (\mathbf{A} + \mathbf{A}^T) + \mathbf{I})^{-1} (\mathbf{y} - \rho \mathbf{b}) + (\frac{\rho}{2} (\mathbf{A} + \mathbf{A}^T) + \mathbf{I})^{-1} (\mathbf{x} - \rho \mathbf{b}) where :math:`\rho > 0` is the penalty. @@ -219,7 +258,8 @@ def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: for the general proximal definition. 
""" lhs = rho * self.A_sym + np.eye(self.A.shape[1]) - rhs = y - self.b * rho + rhs = x - self.b * rho + return np.asarray(np.linalg.solve(lhs, rhs), dtype=float64) def __add__(self, other: Cost) -> Cost: @@ -233,7 +273,11 @@ def __add__(self, other: Cost) -> Cost: if self.shape != other.shape: raise ValueError(f"Mismatching domain shapes: {self.shape} vs {other.shape}") if isinstance(other, QuadraticCost): - return QuadraticCost(self.A + other.A, self.b + other.b, self.c + other.c) + return QuadraticCost( + iop.to_array(self.A + other.A, self.framework, self.device), + iop.to_array(self.b + other.b, self.framework, self.device), + self.c + other.c, + ) if isinstance(other, LinearRegressionCost): return self + other.inner return SumCost([self, other]) @@ -254,10 +298,12 @@ class LinearRegressionCost(Cost): + \frac{1}{2} \mathbf{b}^T\mathbf{b} """ - def __init__(self, A: NDArray[float64], b: NDArray[float64]): # noqa: N803 - if A.shape[0] != b.shape[0]: - raise ValueError(f"Dimension mismatch: A has {A.shape[0]} rows but b has {b.shape[0]} elements") - self.inner = QuadraticCost(A.T.dot(A), -A.T.dot(b), 0.5 * b.dot(b)) + def __init__(self, A: Array, b: Array): # noqa: N803 + if iop.shape(A)[0] != iop.shape(b)[0]: + raise ValueError(f"Dimension mismatch: A has {iop.shape(A)[0]} rows but b has {iop.shape(b)[0]} elements") + self.inner = QuadraticCost( + iop.dot(iop.transpose(A), A), -iop.dot(iop.transpose(A), b), float(0.5 * iop.dot(b, b)) + ) self.A = A self.b = b @@ -265,6 +311,14 @@ def __init__(self, A: NDArray[float64], b: NDArray[float64]): # noqa: N803 def shape(self) -> tuple[int, ...]: # noqa: D102 return self.inner.shape + @property + def framework(self) -> SupportedFrameworks: # noqa: D102 + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: # noqa: D102 + return SupportedDevices.CPU + @property def m_smooth(self) -> float: r""" @@ -299,7 +353,7 @@ def m_cvx(self) -> float: """ return self.inner.m_cvx - def function(self, 
x: NDArray[float64]) -> float: + def function(self, x: Array) -> float: r""" Evaluate function at x. @@ -307,7 +361,7 @@ def function(self, x: NDArray[float64]) -> float: """ return self.inner.function(x) - def gradient(self, x: NDArray[float64]) -> NDArray[float64]: + def gradient(self, x: Array) -> Array: r""" Gradient at x. @@ -315,7 +369,7 @@ def gradient(self, x: NDArray[float64]) -> NDArray[float64]: """ return self.inner.gradient(x) - def hessian(self, x: NDArray[float64]) -> NDArray[float64]: + def hessian(self, x: Array) -> Array: r""" Hessian at x. @@ -323,12 +377,12 @@ def hessian(self, x: NDArray[float64]) -> NDArray[float64]: """ return self.inner.hessian(x) - def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + def proximal(self, x: Array, rho: float) -> Array: r""" - Proximal at y. + Proximal at x. .. math:: - (\rho \mathbf{A}^T \mathbf{A} + \mathbf{I})^{-1} (\mathbf{y} + \rho \mathbf{A}^T\mathbf{b}) + (\rho \mathbf{A}^T \mathbf{A} + \mathbf{I})^{-1} (\mathbf{x} + \rho \mathbf{A}^T\mathbf{b}) where :math:`\rho > 0` is the penalty. @@ -336,7 +390,7 @@ def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: :meth:`Cost.proximal() ` for the general proximal definition. 
""" - return self.inner.proximal(y, rho) + return self.inner.proximal(x, rho) def __add__(self, other: Cost) -> Cost: """Add another cost function.""" @@ -353,27 +407,35 @@ class LogisticRegressionCost(Cost): \log( 1 - \sigma(\mathbf{Ax}) ) \right] """ - def __init__(self, A: NDArray[float64], b: NDArray[float64]): # noqa: N803 - if A.ndim != 2: + def __init__(self, A: Array, b: Array): # noqa: N803 + if len(iop.shape(A)) != 2: raise ValueError("Matrix A must be 2D") - if b.ndim != 1: + if len(iop.shape(b)) != 1: raise ValueError("Vector b must be 1D") - if A.shape[0] != b.shape[0]: - raise ValueError(f"Dimension mismatch: A has shape {A.shape} but b has length {b.shape[0]}") - class_labels = np.unique(b) + if iop.shape(A)[0] != iop.shape(b)[0]: + raise ValueError(f"Dimension mismatch: A has shape {iop.shape(A)} but b has length {iop.shape(b)[0]}") + class_labels = np.unique(iop.to_numpy(b)) if class_labels.shape != (2,): raise ValueError("Vector b must contain exactly two classes") - b = copy.deepcopy(b) - b[np.where(b == class_labels[0])], b[np.where(b == class_labels[1])] = 0, 1 - self.A = A - self.b = b + + self.A: NDArray[float64] = iop.to_numpy(A) + self.b: NDArray[float64] = iop.to_numpy(iop.copy(b)) # Copy b to avoid modifying original array pointer + self.b[np.where(self.b == class_labels[0])], self.b[np.where(self.b == class_labels[1])] = 0, 1 @property def shape(self) -> tuple[int, ...]: # noqa: D102 return (self.A.shape[1],) + @property + def framework(self) -> SupportedFrameworks: # noqa: D102 + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: # noqa: D102 + return SupportedDevices.CPU + @cached_property - def m_smooth(self) -> float: + def m_smooth(self) -> float: # pyright: ignore[reportIncompatibleMethodOverride] r""" The cost function's smoothness constant. @@ -384,8 +446,7 @@ def m_smooth(self) -> float: For the general definition, see :attr:`Cost.m_smooth `. 
""" - res: float = max(pow(la.norm(row), 2) for row in self.A) * self.A.shape[0] / 4 - return res + return float(max(pow(la.norm(row), 2) for row in self.A) * self.A.shape[0] / 4) @property def m_cvx(self) -> float: @@ -397,6 +458,7 @@ def m_cvx(self) -> float: """ return 0 + @iop.autodecorate_cost_method(Cost.function) def function(self, x: NDArray[float64]) -> float: r""" Evaluate function at x. @@ -411,6 +473,7 @@ def function(self, x: NDArray[float64]) -> float: cost = self.b.dot(neg_log_sig) + (1 - self.b).dot(Ax + neg_log_sig) return float(cost) + @iop.autodecorate_cost_method(Cost.gradient) def gradient(self, x: NDArray[float64]) -> NDArray[float64]: r""" Gradient at x. @@ -421,6 +484,7 @@ def gradient(self, x: NDArray[float64]) -> NDArray[float64]: res: NDArray[float64] = self.A.T.dot(sig - self.b) return res + @iop.autodecorate_cost_method(Cost.hessian) def hessian(self, x: NDArray[float64]) -> NDArray[float64]: r""" Hessian at x. @@ -435,15 +499,15 @@ def hessian(self, x: NDArray[float64]) -> NDArray[float64]: res: NDArray[float64] = self.A.T.dot(D).dot(self.A) return res - def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + def proximal(self, x: Array, rho: float) -> Array: """ - Proximal at y solved using an iterative method. + Proximal at x solved using an iterative method. See :meth:`Cost.proximal() ` for the general proximal definition. 
""" - return ca.proximal_solver(self, y, rho) + return ca.proximal_solver(self, x, rho) def __add__(self, other: Cost) -> Cost: """ @@ -456,7 +520,10 @@ def __add__(self, other: Cost) -> Cost: if self.shape != other.shape: raise ValueError(f"Mismatching domain shapes: {self.shape} vs {other.shape}") if isinstance(other, LogisticRegressionCost): - return LogisticRegressionCost(np.vstack([self.A, other.A]), np.concatenate([self.b, other.b])) + return LogisticRegressionCost( + iop.to_array(np.vstack([self.A, other.A]), self.framework, self.device), + iop.to_array(np.concatenate([self.b, other.b]), self.framework, self.device), + ) return SumCost([self, other]) @@ -477,8 +544,16 @@ def __init__(self, costs: list[Cost]): def shape(self) -> tuple[int, ...]: # noqa: D102 return self.costs[0].shape + @property + def framework(self) -> SupportedFrameworks: # noqa: D102 + return self.costs[0].framework + + @property + def device(self) -> SupportedDevices: # noqa: D102 + return self.costs[0].device + @cached_property - def m_smooth(self) -> float: + def m_smooth(self) -> float: # pyright: ignore[reportIncompatibleMethodOverride] r""" The cost function's smoothness constant. @@ -496,7 +571,7 @@ def m_smooth(self) -> float: return np.nan if any(np.isnan(v) for v in m_smooth_vals) else sum(m_smooth_vals) @cached_property - def m_cvx(self) -> float: + def m_cvx(self) -> float: # pyright: ignore[reportIncompatibleMethodOverride] r""" The cost function's convexity constant. 
@@ -513,29 +588,27 @@ def m_cvx(self) -> float: m_cvx_vals = [cf.m_cvx for cf in self.costs] return np.nan if any(np.isnan(v) for v in m_cvx_vals) else sum(m_cvx_vals) - def function(self, x: NDArray[float64]) -> float: + def function(self, x: Array) -> float: """Sum the :meth:`function` of each cost function.""" return sum(cf.function(x) for cf in self.costs) - def gradient(self, x: NDArray[float64]) -> NDArray[float64]: + def gradient(self, x: Array) -> Array: """Sum the :meth:`gradient` of each cost function.""" - res: NDArray[float64] = np.sum([cf.gradient(x) for cf in self.costs], axis=0) - return res + return iop.sum(iop.stack([cf.gradient(x) for cf in self.costs]), dim=0) - def hessian(self, x: NDArray[float64]) -> NDArray[float64]: + def hessian(self, x: Array) -> Array: """Sum the :meth:`hessian` of each cost function.""" - res: NDArray[float64] = np.sum([cf.hessian(x) for cf in self.costs], axis=0) - return res + return iop.sum(iop.stack([cf.hessian(x) for cf in self.costs]), dim=0) - def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + def proximal(self, x: Array, rho: float) -> Array: """ - Proximal at y solved using an iterative method. + Proximal at x solved using an iterative method. See :meth:`Cost.proximal() ` for the general proximal definition. 
""" - return ca.proximal_solver(self, y, rho) + return ca.proximal_solver(self, x, rho) def __add__(self, other: Cost) -> SumCost: """Add another cost function.""" diff --git a/decent_bench/datasets.py b/decent_bench/datasets.py index ffb936d..6516d1a 100644 --- a/decent_bench/datasets.py +++ b/decent_bench/datasets.py @@ -1,18 +1,14 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import NewType +from typing import TypeAlias -from numpy import float64 -from numpy.typing import NDArray from sklearn import datasets -A = NewType("A", NDArray[float64]) -"""Feature matrix type.""" +import decent_bench.utils.interoperability as iop +from decent_bench.utils.array import Array +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks -b = NewType("b", NDArray[float64]) -"""Target vector type.""" - -DatasetPartition = NewType("DatasetPartition", tuple[A, b]) +DatasetPartition: TypeAlias = tuple[Array, Array] # noqa: UP040 """Tuple of (A, b) representing one dataset partition.""" @@ -38,13 +34,22 @@ class SyntheticClassificationData(Dataset): """ - def __init__( - self, n_partitions: int, n_classes: int, n_samples_per_partition: int, n_features: int, seed: int | None = None + def __init__( # noqa: PLR0917 + self, + n_partitions: int, + n_classes: int, + n_samples_per_partition: int, + n_features: int, + framework: SupportedFrameworks, + device: SupportedDevices = SupportedDevices.CPU, + seed: int | None = None, ): self.n_partitions = n_partitions self.n_classes = n_classes self.n_samples_per_partition = n_samples_per_partition self.n_features = n_features + self.framework = framework + self.device = device self.seed = seed def training_partitions(self) -> list[DatasetPartition]: # noqa: D102 @@ -58,5 +63,7 @@ def training_partitions(self) -> list[DatasetPartition]: # noqa: D102 n_classes=self.n_classes, random_state=seed, ) - res.append(DatasetPartition((A(partition[0]), b(partition[1])))) + A = 
iop.to_array(partition[0], self.framework, self.device) # noqa: N806 + b = iop.to_array(partition[1], self.framework, self.device) + res.append((A, b)) return res diff --git a/decent_bench/distributed_algorithms.py b/decent_bench/distributed_algorithms.py index a06e976..7718e79 100644 --- a/decent_bench/distributed_algorithms.py +++ b/decent_bench/distributed_algorithms.py @@ -1,10 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -import numpy as np -from numpy import float64 -from numpy.typing import NDArray - from decent_bench.networks import P2PNetwork from decent_bench.utils import interoperability as iop @@ -58,12 +54,14 @@ def run(self, network: P2PNetwork) -> None: """ for agent in network.agents(): - x0 = np.zeros(agent.cost.shape) + x0 = iop.zeros(framework=agent.cost.framework, shape=agent.cost.shape, device=agent.cost.device) agent.initialize(x=x0, received_msgs=dict.fromkeys(network.neighbors(agent), x0)) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 + + W = network.weights # noqa: N806 for k in range(self.iterations): for i in network.active_agents(k): - neighborhood_avg = iop.sum([W[i, j] * x_j for j, x_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.x i.x = neighborhood_avg - self.step_size * i.cost.gradient(i.x) for i in network.active_agents(k): @@ -108,14 +106,14 @@ def run(self, network: P2PNetwork) -> None: """ for agent in network.agents(): - x0 = np.zeros(agent.cost.shape) + x0 = iop.zeros(framework=agent.cost.framework, shape=agent.cost.shape, device=agent.cost.device) agent.initialize( x=x0, received_msgs=dict.fromkeys(network.neighbors(agent), x0), aux_vars={"y": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 + W = network.weights # noqa: N806 for k in range(self.iterations): # gradient step (a.k.a. 
adapt step) for i in network.active_agents(k): @@ -173,15 +171,17 @@ def run(self, network: P2PNetwork) -> None: """ for agent in network.agents(): - x0 = np.zeros(agent.cost.shape) - y0 = np.zeros(agent.cost.shape) + x0 = iop.zeros(framework=agent.cost.framework, shape=agent.cost.shape, device=agent.cost.device) + y0 = iop.zeros(framework=agent.cost.framework, shape=agent.cost.shape, device=agent.cost.device) neighbors = network.neighbors(agent) agent.initialize(x=x0, received_msgs=dict.fromkeys(neighbors, x0), aux_vars={"y": y0}) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 + + W = network.weights # noqa: N806 for k in range(self.iterations): for i in network.active_agents(k): i.aux_vars["y_new"] = i.x - self.step_size * i.cost.gradient(i.x) - neighborhood_avg = iop.sum([W[i, j] * x_j for j, x_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.x i.x = i.aux_vars["y_new"] - i.aux_vars["y"] + neighborhood_avg i.aux_vars["y"] = i.aux_vars["y_new"] @@ -229,8 +229,8 @@ def run(self, network: P2PNetwork) -> None: """ for i in network.agents(): - x0 = np.zeros(i.cost.shape) - y0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) + y0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) y1 = x0 - self.step_size * i.cost.gradient(x0) # note: msg0's y1 is an approximation of the neighbors' y1 (x0 and y0 are exact: all agents start with same) msg0 = x0 + y1 - y0 @@ -239,15 +239,13 @@ def run(self, network: P2PNetwork) -> None: aux_vars={"y": y0, "y_new": y1}, received_msgs=dict.fromkeys(network.neighbors(i), msg0), ) - W = iop.from_numpy_like( # noqa: N806 - 0.5 * (np.eye(*(network.weights.shape)) + network.weights), - network.agents()[0].x, - ) + + W = network.weights # noqa: N806 + W = 0.5 * (iop.eye_like(W) + W) # noqa: N806 for k 
in range(self.iterations): for i in network.active_agents(k): - i.x = iop.sum([W[i, j] * msg for j, msg in i.messages.items()], dim=0) + W[i, i] * ( - i.x + i.aux_vars["y_new"] - i.aux_vars["y"] - ) + s = iop.stack([W[i, j] * msg for j, msg in i.messages.items()]) + i.x = iop.sum(s, dim=0) + W[i, i] * (i.x + i.aux_vars["y_new"] - i.aux_vars["y"]) i.aux_vars["y"] = i.aux_vars["y_new"] i.aux_vars["y_new"] = i.x - self.step_size * i.cost.gradient(i.x) for i in network.active_agents(k): @@ -301,7 +299,7 @@ def run(self, network: P2PNetwork) -> None: """ for i in network.agents(): - x0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) y0 = i.cost.gradient(x0) neighbors = network.neighbors(i) i.initialize( @@ -310,8 +308,7 @@ def run(self, network: P2PNetwork) -> None: aux_vars={"y": y0, "g": y0, "g_new": x0, "s": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 - + W = network.weights # noqa: N806 for k in range(self.iterations): # 1st communication round # step 1: perform local gradient step and communicate @@ -326,7 +323,7 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - s = iop.stack([W[i, j] * s_j for j, s_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * s_j for j, s_j in i.messages.items()]) neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.aux_vars["s"] i.x = neighborhood_avg @@ -342,7 +339,7 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - s = iop.stack([W[i, j] * q_j for j, q_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * q_j for j, q_j in i.messages.items()]) neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * (i.aux_vars["y"] + i.aux_vars["g_new"] - i.aux_vars["g"]) i.aux_vars["y"] = neighborhood_avg @@ -395,7 +392,7 @@ def run(self, network: P2PNetwork) -> None: """ for i in 
network.agents(): - x0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) neighbors = network.neighbors(i) i.initialize( x=x0, @@ -403,7 +400,7 @@ def run(self, network: P2PNetwork) -> None: aux_vars={"z": x0, "x_old": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 + W = network.weights # noqa: N806 K = 0.5 * (iop.eye_like(W) - W) # noqa: N806 for k in range(self.iterations): @@ -416,7 +413,8 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - neighborhood_avg = iop.sum([K[i, j] * m_j for j, m_j in i.messages.items()], dim=0) + s = iop.stack([K[i, j] * m_j for j, m_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += K[i, i] * (i.x + i.aux_vars["z"]) i.aux_vars["x_old"] = i.x @@ -431,7 +429,8 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - neighborhood_avg = iop.sum([K[i, j] * m_j for j, m_j in i.messages.items()], dim=0) + s = iop.stack([K[i, j] * m_j for j, m_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += K[i, i] * i.aux_vars["x_old"] i.aux_vars["z"] += neighborhood_avg @@ -475,22 +474,21 @@ def run(self, network: P2PNetwork) -> None: """ # initialization (iteration k=0) for i in network.agents(): - x0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) i.initialize( x=x0, received_msgs=dict.fromkeys(network.neighbors(i), x0), aux_vars={"x_old": x0, "x_old_old": x0, "x_cons": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 - + W = network.weights # noqa: N806 # first iteration (iteration k=1) for i in network.active_agents(0): network.broadcast(i, i.x) for i in network.active_agents(0): network.receive_all(i) for i in network.active_agents(0): - s = iop.stack([W[i, j] * x_j for j, x_j in 
i.messages.items()], dim=0) + s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()]) neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.x i.aux_vars["x_cons"] = neighborhood_avg # store W x_k @@ -504,7 +502,7 @@ def run(self, network: P2PNetwork) -> None: for i in network.active_agents(k): network.receive_all(i) for i in network.active_agents(k): - s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()]) neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.x i.aux_vars["x_old_old"] = i.aux_vars["x_old"] # store x_{k-1} @@ -566,7 +564,7 @@ def run(self, network: P2PNetwork) -> None: """ for i in network.agents(): - x0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) y0 = i.cost.gradient(x0) neighbors = network.neighbors(i) i.initialize( @@ -575,8 +573,7 @@ def run(self, network: P2PNetwork) -> None: aux_vars={"y": y0, "g": y0, "g_new": x0, "s": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 - + W = network.weights # noqa: N806 for k in range(self.iterations): # 1st communication round # step 1: perform local gradient step and communicate @@ -591,7 +588,7 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - s = iop.stack([W[i, j] * s_j for j, s_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * s_j for j, s_j in i.messages.items()]) neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * i.aux_vars["s"] i.x = neighborhood_avg @@ -607,7 +604,8 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) for i in network.active_agents(k): - neighborhood_avg = iop.sum([W[i, j] * q_j for j, q_j in i.messages.items()], dim=0) + s = iop.stack([W[i, j] * q_j for j, q_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W[i, i] * 
i.aux_vars["y"] i.aux_vars["y"] = neighborhood_avg + i.aux_vars["g_new"] - i.aux_vars["g"] i.aux_vars["g"] = i.aux_vars["g_new"] @@ -658,14 +656,14 @@ def run(self, network: P2PNetwork) -> None: """ # initialization (iteration k=0) for i in network.agents(): - x0 = np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) i.initialize( x=x0, received_msgs=dict.fromkeys(network.neighbors(i), x0), aux_vars={"x_old": x0, "g": x0, "g_old": x0, "y": x0}, ) - W = iop.from_numpy_like(network.weights, network.agents()[0].x) # noqa: N806 + W = network.weights # noqa: N806 W_tilde = 0.5 * (iop.eye_like(W) + W) # noqa: N806 # first iteration (iteration k=1) @@ -689,14 +687,8 @@ def run(self, network: P2PNetwork) -> None: for i in network.active_agents(k): network.receive_all(i) for i in network.active_agents(k): - s = iop.stack( - [W_tilde[i, j] * y_j for j, y_j in i.messages.items()], - dim=0, - ) - neighborhood_avg = iop.sum( - s, - dim=0, - ) + s = iop.stack([W_tilde[i, j] * y_j for j, y_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, dim=0) neighborhood_avg += W_tilde[i, i] * i.aux_vars["y"] i.aux_vars["x_old"] = i.x # store x_k i.x = neighborhood_avg # update x_{k+1} @@ -741,10 +733,14 @@ def run(self, network: P2PNetwork) -> None: pN = {i: self.rho * len(network.neighbors(i)) for i in network.agents()} # noqa: N806 all_agents = network.agents() for agent in all_agents: - z0 = np.zeros((len(all_agents), *(agent.cost.shape))) - x1 = agent.cost.proximal(y=np.sum(z0, axis=0) / pN[agent], rho=1 / pN[agent]) + z0 = iop.zeros( + framework=agent.cost.framework, + shape=(len(all_agents), *(agent.cost.shape)), + device=agent.cost.device, + ) + x1 = agent.cost.proximal(x=iop.sum(z0, dim=0) / pN[agent], rho=1 / pN[agent]) # note: msg0's x1 is an approximation of the neighbors' x1 (z0 is exact: all agents start with same) - msg0: NDArray[float64] = z0[agent] - 2 * self.rho * x1 + msg0 = z0[agent] - 2 * self.rho * x1 
agent.initialize( x=x1, aux_vars={"z": z0}, @@ -752,7 +748,7 @@ def run(self, network: P2PNetwork) -> None: ) for k in range(self.iterations): for i in network.active_agents(k): - i.x = i.cost.proximal(y=iop.sum(i.aux_vars["z"], dim=0) / pN[i], rho=1 / pN[i]) + i.x = i.cost.proximal(x=iop.sum(i.aux_vars["z"], dim=0) / pN[i], rho=1 / pN[i]) for i in network.active_agents(k): for j in network.neighbors(i): network.send(i, j, i.aux_vars["z"][j] - 2 * self.rho * i.x) @@ -820,9 +816,9 @@ def run(self, network: P2PNetwork) -> None: pN = {i: self.rho * len(network.neighbors(i)) for i in network.agents()} # noqa: N806 all_agents = network.agents() for i in all_agents: - z_y0 = np.zeros((len(all_agents), *(i.cost.shape))) - z_s0 = np.zeros((len(all_agents), *(i.cost.shape))) - x0 = np.zeros(i.cost.shape) + z_y0 = iop.zeros(framework=i.cost.framework, shape=(len(all_agents), *(i.cost.shape)), device=i.cost.device) + z_s0 = iop.zeros(framework=i.cost.framework, shape=(len(all_agents), *(i.cost.shape)), device=i.cost.device) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) i.initialize( x=x0, aux_vars={"y": x0, "s": x0, "z_y": z_y0, "z_s": z_s0}, @@ -841,13 +837,14 @@ def run(self, network: P2PNetwork) -> None: for i in network.active_agents(k): for j in network.neighbors(i): # transmit the messages as a single message, stacking along the first axis - network.send(i, j, - iop.stack( - ( - -i.aux_vars["z_y"][j] + 2 * self.rho * i.aux_vars["y"], - -i.aux_vars["z_s"][j] + 2 * self.rho * i.aux_vars["s"], - ), dim=0), - ) # fmt: skip + s = iop.stack( + ( + -i.aux_vars["z_y"][j] + 2 * self.rho * i.aux_vars["y"], + -i.aux_vars["z_s"][j] + 2 * self.rho * i.aux_vars["s"], + ), + dim=0, + ) + network.send(i, j, s) for i in network.active_agents(k): network.receive_all(i) for i in network.active_agents(k): @@ -907,10 +904,12 @@ def run(self, network: P2PNetwork) -> None: """ all_agents = network.agents() for i in all_agents: - x0 = 
np.zeros(i.cost.shape) + x0 = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) + # y must be initialized to zero + y = iop.zeros(framework=i.cost.framework, shape=i.cost.shape, device=i.cost.device) i.initialize( x=x0, - aux_vars={"y": np.zeros(i.cost.shape)}, # y must be initialized to zero + aux_vars={"y": y}, ) # step 0: first communication round @@ -920,7 +919,7 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) # compute and store \sum_j (\mathbf{x}_{i,0} - \mathbf{x}_{j,0}) for i in network.active_agents(0): - s = iop.stack([i.x - x_j for _, x_j in i.messages.items()], dim=0) + s = iop.stack([i.x - x_j for _, x_j in i.messages.items()]) i.aux_vars["s"] = iop.sum(s, dim=0) # pyright: ignore[reportArgumentType] # main iteration @@ -938,7 +937,7 @@ def run(self, network: P2PNetwork) -> None: network.receive_all(i) # compute and store \sum_j (\mathbf{x}_{i,k+1} - \mathbf{x}_{j,k+1}) for i in network.active_agents(k): - s = iop.stack([i.x - x_j for _, x_j in i.messages.items()], dim=0) + s = iop.stack([i.x - x_j for _, x_j in i.messages.items()]) i.aux_vars["s"] = iop.sum(s, dim=0) # pyright: ignore[reportArgumentType] # step 3: update dual variable diff --git a/decent_bench/metrics/metric_utils.py b/decent_bench/metrics/metric_utils.py index 0fb3eba..dfac782 100644 --- a/decent_bench/metrics/metric_utils.py +++ b/decent_bench/metrics/metric_utils.py @@ -7,8 +7,10 @@ from numpy.linalg import LinAlgError from numpy.typing import NDArray +import decent_bench.utils.interoperability as iop from decent_bench.agents import AgentMetricsView from decent_bench.benchmark_problem import BenchmarkProblem +from decent_bench.utils.array import Array def single(values: Sequence[float]) -> float: @@ -25,7 +27,7 @@ def single(values: Sequence[float]) -> float: @cache -def x_mean(agents: tuple[AgentMetricsView, ...], iteration: int = -1) -> NDArray[float64]: +def x_mean(agents: tuple[AgentMetricsView, ...], iteration: int = -1) -> 
Array: """ Calculate the mean x at *iteration* (or using the agents' final x if *iteration* is -1). @@ -38,8 +40,8 @@ def x_mean(agents: tuple[AgentMetricsView, ...], iteration: int = -1) -> NDArray all_x_at_iter = [a.x_history[iteration] for a in agents if len(a.x_history) > iteration] if len(all_x_at_iter) == 0: raise ValueError(f"No agent reached iteration {iteration}") - res: NDArray[float64] = np.mean(all_x_at_iter, axis=0) - return res + + return iop.mean(iop.stack(all_x_at_iter), dim=0) def regret(agents: list[AgentMetricsView], problem: BenchmarkProblem, iteration: int = -1) -> float: @@ -66,7 +68,7 @@ def gradient_norm(agents: list[AgentMetricsView], iteration: int = -1) -> float: .. include:: snippets/global_gradient_optimality.rst """ mean_x = x_mean(tuple(agents), iteration) - grad_avg = sum(a.cost.gradient(mean_x) for a in agents) / len(agents) + grad_avg = sum(iop.to_numpy(a.cost.gradient(mean_x)) for a in agents) / len(agents) return float(la.norm(grad_avg)) ** 2 @@ -81,7 +83,7 @@ def x_error(agent: AgentMetricsView, problem: BenchmarkProblem) -> NDArray[float where :math:`\mathbf{x}_k` is the agent's local x at iteration k, and :math:`\mathbf{x}^\star` is the optimal x defined in the *problem*. 
""" - x_per_iteration = np.asarray(agent.x_history) + x_per_iteration = np.asarray([iop.to_numpy(x) for x in agent.x_history]) opt_x = problem.x_optimal errors: NDArray[float64] = la.norm(x_per_iteration - opt_x, axis=tuple(range(1, x_per_iteration.ndim))) return errors diff --git a/decent_bench/metrics/table_metrics.py b/decent_bench/metrics/table_metrics.py index 9cb95fa..2bc1130 100644 --- a/decent_bench/metrics/table_metrics.py +++ b/decent_bench/metrics/table_metrics.py @@ -9,6 +9,7 @@ from scipy import stats import decent_bench.metrics.metric_utils as utils +import decent_bench.utils.interoperability as iop from decent_bench.agents import AgentMetricsView from decent_bench.benchmark_problem import BenchmarkProblem from decent_bench.distributed_algorithms import Algorithm @@ -87,7 +88,7 @@ class XError(TableMetric): description: str = "x error" def get_data_from_trial(self, agents: list[AgentMetricsView], problem: BenchmarkProblem) -> list[float]: # noqa: D102 - return [float(la.norm(problem.x_optimal - a.x_history[-1])) for a in agents] + return [float(la.norm(iop.to_numpy(problem.x_optimal) - iop.to_numpy(a.x_history[-1]))) for a in agents] class AsymptoticConvergenceOrder(TableMetric): diff --git a/decent_bench/networks.py b/decent_bench/networks.py index d4f3fb5..a81095d 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -7,17 +7,17 @@ import networkx as nx import numpy as np from networkx import Graph -from numpy import float64 -from numpy.typing import NDArray +import decent_bench.utils.interoperability as iop from decent_bench.agents import Agent from decent_bench.benchmark_problem import BenchmarkProblem from decent_bench.schemes import CompressionScheme, DropScheme, NoiseScheme +from decent_bench.utils.array import Array if TYPE_CHECKING: - AgentGraph = Graph[Agent] + AgentGraph = nx.Graph[Agent] else: - AgentGraph = Graph + AgentGraph = nx.Graph class Network(ABC): @@ -73,7 +73,7 @@ def active_agents(self, iteration: int) -> 
list[Agent]: """ return [a for a in self.agents() if a._activation.is_active(iteration)] # noqa: SLF001 - def send(self, sender: Agent, receiver: Agent, msg: NDArray[float64]) -> None: + def send(self, sender: Agent, receiver: Agent, msg: Array) -> None: """ Send message to a neighbor. @@ -123,19 +123,41 @@ def __init__( message_compression=message_compression, message_drop=message_drop, ) + self.W: Array | None = None def kind(self) -> str: """Label for the network subtype.""" return "p2p" - @cached_property - def weights(self) -> NDArray[float64]: + def set_weights(self, weights: Array) -> None: + """ + Set custom consensus weights matrix. + + A simple way to create custom weights is to start using numpy and then + use :func:`~decent_bench.utils.interoperability.to_array` to convert to an + :class:`~decent_bench.utils.array.Array` object with the desired framework and device. + For an example see :func:`~decent_bench.utils.interoperability.zeros`. + + Note: + If not set, the weights matrix is initialized using the Metropolis-Hastings method. + Weights will be overwritten if framework or device differ from + ``Agent.cost.framework`` or ``Agent.cost.device``. + + """ + self.W = weights + + @property + def weights(self) -> Array: """ Symmetric, doubly stochastic matrix for consensus weights. Initialized using the Metropolis-Hastings method. Use ``weights[i, j]`` or ``weights[i.id, j.id]`` to get the weight between agent i and j. """ agents = self.agents() + + if self.W is not None: + return self.W + n = len(agents) W = np.zeros((n, n)) # noqa: N806 for i in agents: @@ -146,13 +168,32 @@ def weights(self) -> NDArray[float64]: W[i, j] = 1 / (1 + max(d_i, d_j)) for i in agents: W[i, i] = 1 - sum(W[i]) - return W + self.W = iop.to_array(W, agents[0].cost.framework, agents[0].cost.device) + return self.W + + @cached_property + def adjacency(self) -> Array: + """ + Adjacency matrix of the network. 
+ + Use ``adjacency[i, j]`` or ``adjacency[i.id, j.id]`` to get the adjacency between agent i and j. + """ + agents = self.agents() + n = len(agents) + A = np.zeros((n, n)) # noqa: N806 + for i in agents: + for j in self.neighbors(i): + A[i, j] = 1 + + return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device) + def neighbors(self, agent: Agent) -> list[Agent]: """Get all neighbors of an agent.""" return list(self.graph[agent]) - def broadcast(self, sender: Agent, msg: NDArray[float64]) -> None: + + def broadcast(self, sender: Agent, msg: Array) -> None: """ Send message to all neighbors. @@ -228,7 +269,7 @@ def active_clients(self, iteration: int) -> list[Agent]: """ return [agent for agent in self.active_agents(iteration) if agent is not self.server] - def send_to_client(self, client: Agent, msg: NDArray[float64]) -> None: + def send_to_client(self, client: Agent, msg: Array) -> None: """ Send a message from the server to a specific client. @@ -240,12 +281,12 @@ def send_to_client(self, client: Agent, msg: NDArray[float64]) -> None: raise ValueError("Receiver must be a client") self.send(sender=self.server, receiver=client, msg=msg) - def send_to_all_clients(self, msg: NDArray[float64]) -> None: + def send_to_all_clients(self, msg: Array) -> None: """Send the same message from the server to every client (synchronous FL push).""" for client in self.clients: self.send_to_client(client, msg) - def send_from_client(self, client: Agent, msg: NDArray[float64]) -> None: + def send_from_client(self, client: Agent, msg: Array) -> None: """ Send a message from a client to the server. 
@@ -257,7 +298,7 @@ def send_from_client(self, client: Agent, msg: NDArray[float64]) -> None: raise ValueError("Sender must be a client") self.send(sender=client, receiver=self.server, msg=msg) - def send_from_all_clients(self, msgs: Mapping[Agent, NDArray[float64]]) -> None: + def send_from_all_clients(self, msgs: Mapping[Agent, Array]) -> None: """ Send messages from each client to the server (synchronous FL push). diff --git a/decent_bench/schemes.py b/decent_bench/schemes.py index 1464b61..addc2d1 100644 --- a/decent_bench/schemes.py +++ b/decent_bench/schemes.py @@ -1,11 +1,10 @@ import random from abc import ABC, abstractmethod -from functools import cached_property import numpy as np -from numpy import float64 -from numpy.random import MT19937, Generator -from numpy.typing import NDArray + +import decent_bench.utils.interoperability as iop +from decent_bench.utils.array import Array class AgentActivationScheme(ABC): @@ -43,14 +42,14 @@ class CompressionScheme(ABC): """Scheme defining how messages are compressed when sent over the network.""" @abstractmethod - def compress(self, msg: NDArray[float64]) -> NDArray[float64]: + def compress(self, msg: Array) -> Array: """Apply compression and return a new, compressed message.""" class NoCompression(CompressionScheme): """Scheme that leaves messages uncompressed.""" - def compress(self, msg: NDArray[float64]) -> NDArray[float64]: # noqa: D102 + def compress(self, msg: Array) -> Array: # noqa: D102 return msg @@ -60,9 +59,9 @@ class Quantization(CompressionScheme): def __init__(self, n_significant_digits: int): self.n_significant_digits = n_significant_digits - def compress(self, msg: NDArray[float64]) -> NDArray[float64]: # noqa: D102 - res: NDArray[float64] = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(msg) - return res + def compress(self, msg: Array) -> Array: # noqa: D102 + res = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(iop.to_numpy(msg)) + return 
iop.to_array_like(res, msg) class DropScheme(ABC): @@ -96,14 +95,14 @@ class NoiseScheme(ABC): """Scheme defining how noise impacts messages.""" @abstractmethod - def make_noise(self, msg: NDArray[float64]) -> NDArray[float64]: + def make_noise(self, msg: Array) -> Array: """Apply noise scheme without mutating the *msg* passed in.""" class NoNoise(NoiseScheme): """Scheme that leaves messages untouched.""" - def make_noise(self, msg: NDArray[float64]) -> NDArray[float64]: # noqa: D102 + def make_noise(self, msg: Array) -> Array: # noqa: D102 return msg @@ -116,9 +115,5 @@ def __init__(self, mean: float, sd: float): self.mean = mean self.sd = sd - @cached_property - def _generator(self) -> Generator: - return Generator(MT19937()) - - def make_noise(self, msg: NDArray[float64]) -> NDArray[float64]: # noqa: D102 - return msg + self._generator.normal(self.mean, self.sd, msg.shape) + def make_noise(self, msg: Array) -> Array: # noqa: D102 + return msg + iop.randn_like(msg, mean=self.mean, std=self.sd) diff --git a/decent_bench/utils/array.py b/decent_bench/utils/array.py new file mode 100644 index 0000000..44b1824 --- /dev/null +++ b/decent_bench/utils/array.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Self + +import decent_bench.utils.interoperability as iop +from decent_bench.utils.types import ArrayKey, SupportedArrayTypes + + +class Array: # noqa: PLR0904 + """ + A wrapper class for :data:`~decent_bench.utils.types.SupportedArrayTypes` objects to enable operator overloading. + + This class allows for seamless interoperability between different array/tensor frameworks + by overloading standard arithmetic operators. Operations supported are addition, subtraction, multiplication, + division, matrix multiplication, exponentiation, negation and in-place operations. 
+ + Note: + Instantiation of this class is typically done through the functions in + :mod:`~decent_bench.utils.interoperability` rather than direct instantiation. + This is to ensure proper handling of different underlying array types. + + """ + + def __init__( + self, + value: SupportedArrayTypes, + ): + """ + Initialize the Array object. + + Can be initialized either by providing a array-like `value` or using one + of the methods in :mod:`decent_bench.utils.interoperability`. + + Args: + value (SupportedArrayTypes): The array-like object to wrap. + + """ + self.value: SupportedArrayTypes = value + + def __add__(self, other: Array | SupportedArrayTypes) -> Array: + """ + Add another Array object or SupportedArrayTypes to this one. + + Args: + other: The object to add. + + Returns: + The result of the addition. + + """ + return iop.add(self, other) + + __radd__ = __add__ + + def __sub__(self, other: Array | SupportedArrayTypes) -> Array: + """ + Subtract another Array object or a scalar from this one. + + Args: + other: The object to subtract. + + Returns: + The result of the subtraction. + + """ + return iop.sub(self, other) + + def __rsub__(self, other: SupportedArrayTypes) -> Array: + """ + Handle right-side subtraction. + + Args: + other: The object to be subtracted from. + + Returns: + The result of the subtraction. + + """ + return iop.sub(other, self) + + def __mul__(self, other: Array | SupportedArrayTypes) -> Array: + """ + Multiply this object by another Array object or a scalar. + + Args: + other: The object to multiply by. + + Returns: + The result of the multiplication. + + """ + return iop.mul(self, other) + + def __truediv__(self, other: Array | SupportedArrayTypes) -> Array: + """ + Divide this object by another Array object or a scalar. + + Args: + other: The object to divide by. + + Returns: + The result of the division. 
+ + """ + return iop.div(self, other) + + def __matmul__(self, other: Array | SupportedArrayTypes) -> Array: + """ + Perform matrix multiplication with another Array object. + + Args: + other: The object to multiply with. + + Returns: + The result of the matrix multiplication. + + """ + return iop.matmul(self, other) + + def __rmatmul__(self, other: SupportedArrayTypes) -> Array: + """ + Perform right-side matrix multiplication with another Array object. + + Args: + other: The object to multiply with. + + Returns: + The result of the matrix multiplication. + + """ + return iop.matmul(other, self) + + def __rmul__(self, other: SupportedArrayTypes) -> Array: + """ + Handle right-side multiplication. + + Args: + other: The object to multiply by. + + Returns: + The result of the multiplication. + + """ + return iop.mul(other, self) + + def __rtruediv__(self, other: SupportedArrayTypes) -> Array: + """ + Handle right-side division. + + Args: + other: The object to be divided by the array. + + Returns: + The result of the division. + + """ + return iop.div(other, self) + + def __pow__(self, other: float) -> Array: + """ + Raise the wrapped tensor to a power. + + Args: + other: The power. + + Returns: + The result of the operation. + + """ + return iop.power(self, other) + + def __iadd__(self, other: Array | SupportedArrayTypes) -> Self: + """ + Perform in-place addition. + + Args: + other: The object to add. + + Returns: + The modified object. + + """ + return iop.ext.iadd(self, other) + + def __isub__(self, other: Array | SupportedArrayTypes) -> Self: + """ + Perform in-place subtraction. + + Args: + other: The object to subtract. + + Returns: + The modified object. + + """ + return iop.ext.isub(self, other) + + def __imul__(self, other: Array | SupportedArrayTypes) -> Self: + """ + Perform in-place multiplication. + + Args: + other: The object to multiply by. + + Returns: + The modified object. 
+ + """ + return iop.ext.imul(self, other) + + def __itruediv__(self, other: Array | SupportedArrayTypes) -> Self: + """ + Perform in-place division. + + Args: + other: The object to divide by. + + Returns: + The modified object. + + """ + return iop.ext.idiv(self, other) + + def __ipow__(self, other: float) -> Self: + """ + Perform in-place power operation. + + Args: + other: The power. + + Returns: + The modified object. + + """ + return iop.ext.ipow(self, other) + + def __neg__(self) -> Array: + """ + Negate the wrapped tensor. + + Returns: + The negated tensor. + + """ + return iop.negative(self) + + def __abs__(self) -> Array: + """ + Return the absolute value of the wrapped tensor. + + Returns: + The absolute value. + + """ + return iop.absolute(self) + + def __getitem__(self, key: ArrayKey) -> Array: + """ + Get an item or slice from the wrapped tensor. + + Args: + key (ArrayKey): The key or slice. + + Returns: + The resulting item or slice. + + Raises: + TypeError: If the wrapped value is a scalar. + + """ + if isinstance(self.value, (float, int, complex)): + raise TypeError("Scalar values do not support indexing.") + + return iop.get_item(self, key) + + def __setitem__(self, key: ArrayKey, value: Array | SupportedArrayTypes) -> None: + """ + Set an item or slice in the wrapped tensor. + + Be aware that this operation may not be supported by all underlying frameworks. + JAX and TensorFlow, for example, use immutable arrays by default. + + Args: + key (ArrayKey): The key or slice. + value: The value to set. + + Raises: + TypeError: If the wrapped value is a scalar. 
+ + """ + if isinstance(self.value, (float, int, complex)): + raise TypeError("Scalar values do not support indexing.") + + iop.set_item(self, key, value) + + def __repr__(self) -> str: + """Return the official string representation of the object.""" + return f"Array({self.value!r})" + + def __str__(self) -> str: + """Return the user-friendly string representation of the object.""" + return str(self.value) + + def __len__(self) -> int: + """ + Return the length of the first dimension. + + Raises: + TypeError: If the wrapped value is a scalar. + + """ + if isinstance(self.value, (float, int, complex)): + raise TypeError("Scalar values do not have length.") + return len(self.value) + + def __iter__(self) -> Iterator[SupportedArrayTypes]: + """ + Return an iterator over the first dimension, yielding array elements. + + Raises: + TypeError: If the wrapped value is a scalar. + + """ + if isinstance(self.value, (float, int, complex)): + raise TypeError("Scalar values are not iterable.") + return iter(self.value) + + def __array__(self) -> SupportedArrayTypes: # noqa: PLW3201 + """ + Return the underlying array-like object. + + Returns: + The wrapped tensor-like object. + + """ + return self.value + + def __float__(self) -> float: + """ + Return the wrapped tensor as a float. + + Returns: + The float representation of the wrapped tensor. 
+ + """ + return iop.astype(self, float) diff --git a/decent_bench/utils/interoperability.py b/decent_bench/utils/interoperability.py deleted file mode 100644 index 8bd4356..0000000 --- a/decent_bench/utils/interoperability.py +++ /dev/null @@ -1,690 +0,0 @@ -"""Utilities for operating on arrays from different deep learning and linear algebra frameworks.""" - -from __future__ import annotations - -import contextlib -import random -from collections.abc import Sequence -from copy import deepcopy -from typing import Any, TypeVar, cast - -import numpy as np -from numpy.typing import ArrayLike, NDArray - -torch = None -with contextlib.suppress(ImportError, ModuleNotFoundError): - import torch as _torch - - torch = _torch - -tf = None -with contextlib.suppress(ImportError, ModuleNotFoundError): - import tensorflow as _tf - - tf = _tf - -jnp = None -with contextlib.suppress(ImportError, ModuleNotFoundError): - import jax.numpy as _jnp - - jnp = _jnp - -jax_random = None -with contextlib.suppress(ImportError, ModuleNotFoundError): - from jax import random as _jax_random - - jax_random = _jax_random - -T = TypeVar("T", bound=ArrayLike) -""" -TypeVar for ArrayLike types such as NumPy arrays, PyTorch tensors or -nested containers (lists/tuples). - -This TypeVar is used throughout the interoperability utilities to ensure that -operations preserve the input framework type. -""" - - -def to_numpy(array: ArrayLike) -> NDArray[Any]: - """ - Convert input array to a NumPy array. - - Args: - array (ArrayLike): Input array (NumPy, torch, tf, jax) or nested container (list, tuple). - - Returns: - NDArray: Converted NumPy array. 
- - """ - if isinstance(array, np.ndarray): - return array - if torch and isinstance(array, torch.Tensor): - return cast("np.ndarray", array.cpu().numpy()) - if tf and isinstance(array, tf.Tensor): - return cast("np.ndarray", array.numpy()) - if (jnp and isinstance(array, jnp.ndarray)) or isinstance(array, (list, tuple)): - return np.array(array) - return np.array(array) - - -def from_numpy_like[T: ArrayLike](array: NDArray[Any], like: T) -> T: - """ - Convert a NumPy array to the specified framework type. - - Args: - array (NDArray): Input NumPy array. - like (ArrayLike): Example array of the target framework type (e.g., torch.Tensor, tf.Tensor). - - Returns: - ArrayLike: Converted array in the specified framework type. - - Raises: - TypeError: if the framework type of `like` is unsupported. - - """ - device = None - if hasattr(like, "device"): - device = like.device - - if isinstance(like, np.ndarray): - return cast("T", array) - if torch and isinstance(like, torch.Tensor): - return cast("T", torch.from_numpy(array).to(device)) - if tf and isinstance(like, tf.Tensor): - with tf.device(device): - return cast("T", tf.convert_to_tensor(array)) - if jnp and isinstance(like, jnp.ndarray): - return cast("T", jnp.array(array, device=device)) - if isinstance(like, (list, tuple)): - return cast("T", type(like)(array.tolist())) - raise TypeError(f"Unsupported framework type: {type(like)}") - - -def sum[T: ArrayLike]( # noqa: A001 - array: T, - dim: int | tuple[int, ...] | None = None, - keepdims: bool = False, -) -> T: - """ - Sum elements of an array. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | tuple[int, ...] | None): Dimension or dimensions along which to sum. - If None, sums over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Summed value in the same framework type as the input. 
- - Raises: - TypeError: if the framework type of `array` is unsupported. - - """ - if isinstance(array, np.ndarray): - return cast("T", np.sum(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.sum(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.reduce_sum(array, axis=dim, keepdims=keepdims)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.sum(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - summed = np.sum(np_array, axis=dim, keepdims=keepdims).tolist() - return cast("T", type(array)(summed if isinstance(summed, list) else [summed])) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def mean[T: ArrayLike]( - array: T, - dim: int | tuple[int, ...] | None = None, - keepdims: bool = False, -) -> T: - """ - Compute mean of array elements. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute mean. - If None, computes mean over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Mean value in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.mean(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.mean(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.reduce_mean(array, axis=dim, keepdims=keepdims)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.mean(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - meaned = np.mean(np_array, axis=dim, keepdims=keepdims).tolist() - return cast("T", type(array)(meaned if isinstance(meaned, list) else [meaned])) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def min[T: ArrayLike]( # noqa: A001 - array: T, - dim: int | tuple[int, ...] | None = None, - keepdims: bool = False, -) -> T: - """ - Compute minimum of array elements. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute minimum. - If None, finds minimum over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Minimum value in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.min(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.amin(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.reduce_min(array, axis=dim, keepdims=keepdims)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.min(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - mined = np.min(np_array, axis=dim, keepdims=keepdims).tolist() - return cast("T", type(array)(mined if isinstance(mined, list) else [mined])) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def max[T: ArrayLike]( # noqa: A001 - array: T, - dim: int | tuple[int, ...] | None = None, - keepdims: bool = False, -) -> T: - """ - Compute maximum of array elements. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute maximum. - If None, finds maximum over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Maximum value in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.max(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.amax(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.reduce_max(array, axis=dim, keepdims=keepdims)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.max(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - maxed = np.max(np_array, axis=dim, keepdims=keepdims).tolist() - return cast("T", type(array)(maxed if isinstance(maxed, list) else [maxed])) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def argmax[T: ArrayLike](array: T, dim: int | None = None, keepdims: bool = False) -> T: - """ - Compute index of maximum value. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | None): Dimension along which to find maximum. If None, finds maximum over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Indices of maximum values in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.argmax(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.argmax(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - ret = None - if dim is None: - # TensorFlow's argmax does not support dim=None directly - dims = array.ndim if array.ndim is not None else 0 - array = tf.reshape(array, [-1]) - ret = ( - cast("T", tf.math.argmax(array, axis=0)) - if not keepdims - else cast("T", tf.reshape(tf.math.argmax(array, axis=0), [1] * dims)) - ) - else: - ret = ( - cast("T", tf.math.argmax(array, axis=dim)) - if not keepdims - else cast("T", tf.expand_dims(tf.math.argmax(array, axis=dim), axis=dim)) - ) - return ret - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.argmax(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - argmaxed = np.argmax(np_array, axis=dim, keepdims=keepdims).tolist() - return cast( - "T", - type(array)(argmaxed if isinstance(argmaxed, list) else [argmaxed]), - ) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def argmin[T: ArrayLike](array: T, dim: int | None = None, keepdims: bool = False) -> T: - """ - Compute index of minimum value. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (int | None): Dimension along which to find minimum. If None, finds minimum over flattened array. - keepdims (bool): If True, retains reduced dimensions with length 1. - - Returns: - ArrayLike: Indices of minimum values in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.argmin(array, axis=dim, keepdims=keepdims)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.argmin(array, dim=dim, keepdim=keepdims)) - if tf and isinstance(array, tf.Tensor): - ret = None - if dim is None: - # TensorFlow's argmin does not support dim=None directly - dims = array.ndim if array.ndim is not None else 0 - array = tf.reshape(array, [-1]) - ret = ( - cast("T", tf.math.argmin(array, axis=0)) - if not keepdims - else cast("T", tf.reshape(tf.math.argmin(array, axis=0), [1] * dims)) - ) - else: - ret = ( - cast("T", tf.math.argmin(array, axis=dim)) - if not keepdims - else cast("T", tf.expand_dims(tf.math.argmin(array, axis=dim), axis=dim)) - ) - return ret - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.argmin(array, axis=dim, keepdims=keepdims)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - argmined = np.argmin(np_array, axis=dim, keepdims=keepdims).tolist() - return cast( - "T", - type(array)(argmined if isinstance(argmined, list) else [argmined]), - ) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def copy[T: ArrayLike](array: T) -> T: - """ - Create a copy of the input array. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - - Returns: - ArrayLike: A copy of the input array in the same framework type. - - """ - if isinstance(array, np.ndarray): - return cast("T", np.copy(array)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.clone(array)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.identity(array)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.array(array, copy=True)) - return deepcopy(array) - - -def stack[T: ArrayLike](arrays: Sequence[T], dim: int = 0) -> T: - """ - Stack a sequence of arrays along a new dimension. 
- - Args: - arrays (Sequence[ArrayLike]): Sequence of input arrays (NumPy, PyTorch, TensorFlow, JAX) - or nested containers (list, tuple). - dim (int): Dimension along which to stack the arrays. - - Returns: - ArrayLike: Stacked array in the same framework type as the inputs. - - Raises: - TypeError: if the framework type of the input arrays is unsupported. - - """ - if isinstance(arrays[0], np.ndarray): - return cast("T", np.stack(arrays, axis=dim)) - if torch and isinstance(arrays[0], torch.Tensor): - return cast("T", torch.stack(arrays, dim=dim)) - if tf and isinstance(arrays[0], tf.Tensor): - return cast("T", tf.stack(arrays, axis=dim)) - if jnp and isinstance(arrays[0], jnp.ndarray): - return cast("T", jnp.stack(arrays, axis=dim)) - if isinstance(arrays[0], (list, tuple)): - return cast("T", type(arrays[0])(np.stack(arrays, axis=dim).tolist())) - - raise TypeError(f"Unsupported framework type: {type(arrays[0])}") - - -def reshape[T: ArrayLike](array: T, shape: tuple[int, ...]) -> T: - """ - Reshape an array to the specified shape. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - shape (tuple[int, ...]): Desired shape for the output array. - - Returns: - ArrayLike: Reshaped array in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.reshape(array, shape)) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.reshape(array, shape)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.reshape(array, shape)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.reshape(array, shape)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - reshaped = np.reshape(np_array, shape) - return cast("T", type(array)(reshaped.tolist())) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def zeros_like[T: ArrayLike](array: T) -> T: - """ - Create an array of zeros with the same shape and type as the input. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - - Returns: - ArrayLike: Array of zeros in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. - - """ - if isinstance(array, np.ndarray): - return np.zeros_like(array) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.zeros_like(array)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.zeros_like(array)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.zeros_like(array)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - return cast("T", type(array)(np.zeros_like(np_array).tolist())) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def ones_like[T: ArrayLike](array: T) -> T: - """ - Create an array of ones with the same shape and type as the input. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - - Returns: - ArrayLike: Array of ones in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return np.ones_like(array) - if torch and isinstance(array, torch.Tensor): - return cast("T", torch.ones_like(array)) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.ones_like(array)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.ones_like(array)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - return cast("T", type(array)(np.ones_like(np_array).tolist())) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def rand_like[T: ArrayLike](array: T, low: float = 0.0, high: float = 1.0) -> T: - """ - Create an array of random values with the same shape and type as the input. - - Values are drawn uniformly from [low, high). - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - low (float): Lower bound of the uniform distribution. - high (float): Upper bound of the uniform distribution. - - Returns: - ArrayLike: Array of random values in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - random_array = np.random.default_rng().uniform( - low=low, - high=high, - size=array.shape, - ) - random_array = random_array.astype(array.dtype) if isinstance(random_array, np.ndarray) else random_array - return cast("T", random_array) - if torch and isinstance(array, torch.Tensor): - return cast("T", (high - low) * torch.rand_like(array) + low) - if tf and isinstance(array, tf.Tensor): - return cast( - "T", - tf.random.uniform( - tf.shape(array), - dtype=array.dtype, - minval=low, - maxval=high, - ), - ) - if jnp and jax_random and isinstance(array, jnp.ndarray): - return cast( - "T", - jax_random.uniform( - jax_random.key(random.randint(0, 2**32 - 1)), - shape=array.shape, - dtype=array.dtype, - minval=low, - maxval=high, - ), - ) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - np_random_array = np.random.default_rng().uniform( - low=low, - high=high, - size=np_array.shape, - ) - np_random_array = ( - np_random_array.astype(np_array.dtype).tolist() - if isinstance(np_random_array, np.ndarray) - else np_random_array - ) - return cast("T", type(array)(np_random_array)) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def randn_like[T: ArrayLike](array: T, mean: float = 0.0, std: float = 1.0) -> T: - """ - Create an array of random values with the same shape and type as the input. - - Values are drawn from a normal distribution with mean `mean` and standard deviation `std`. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - mean (float): Mean of the normal distribution. - std (float): Standard deviation of the normal distribution. - - Returns: - ArrayLike: Array of random values in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - random_array = np.random.default_rng().normal( - loc=mean, - scale=std, - size=array.shape, - ) - random_array = random_array.astype(array.dtype) if isinstance(random_array, np.ndarray) else random_array - return cast("T", random_array) - if torch and isinstance(array, torch.Tensor): - return cast( - "T", - torch.normal( - mean=mean, - std=std, - size=array.shape, - dtype=array.dtype, - device=array.device, - ), - ) - if tf and isinstance(array, tf.Tensor): - shape = tf.shape(array) - return cast("T", tf.random.normal(shape, mean=mean, stddev=std, dtype=array.dtype)) - if jnp and jax_random and isinstance(array, jnp.ndarray): - return cast( - "T", - mean - + std - * jax_random.normal( - jax_random.key(random.randint(0, 2**32 - 1)), - shape=array.shape, - dtype=array.dtype, - ), - ) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - np_random_array = np.random.default_rng().normal(loc=mean, scale=std, size=np_array.shape) - np_random_array = ( - np_random_array.astype(np_array.dtype).tolist() - if isinstance(np_random_array, np.ndarray) - else np_random_array - ) - return cast("T", type(array)(np_random_array)) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def eye_like[T: ArrayLike](array: T) -> T: - """ - Create an identity matrix with the same shape as the input. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - - Returns: - ArrayLike: Identity matrix in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.eye(*array.shape[-2:], dtype=array.dtype, device=array.device)) - if torch and isinstance(array, torch.Tensor): - return cast( - "T", - torch.eye(*array.shape[-2:], dtype=array.dtype, device=array.device), - ) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.eye(*array.shape[-2:], dtype=array.dtype)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.eye(*array.shape[-2:], dtype=array.dtype, device=array.device)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - eye_array = np.eye(*np_array.shape[-2:]) - return cast("T", type(array)(eye_array.tolist())) - - raise TypeError(f"Unsupported framework type: {type(array)}") - - -def transpose[T: ArrayLike](array: T, dim: tuple[int, ...] | None = None) -> T: - """ - Transpose an array. - - Args: - array (ArrayLike): Input array (NumPy, PyTorch, TensorFlow, JAX) or nested container (list, tuple). - dim (tuple[int, ...] | None): Desired dim order. If None, reverses the dimensions. - - Returns: - ArrayLike: Transposed array in the same framework type as the input. - - Raises: - TypeError: if the framework type of `array` is unsupported. 
- - """ - if isinstance(array, np.ndarray): - return cast("T", np.transpose(array, axes=dim)) - if torch and isinstance(array, torch.Tensor): - # Handle None case for PyTorch - return ( - cast("T", torch.permute(array, dims=dim)) - if dim - else cast("T", torch.permute(array, dims=list(reversed(range(array.ndim))))) - ) - if tf and isinstance(array, tf.Tensor): - return cast("T", tf.transpose(array, perm=dim)) - if jnp and isinstance(array, jnp.ndarray): - return cast("T", jnp.transpose(array, axes=dim)) - if isinstance(array, (list, tuple)): - np_array = to_numpy(array) - transposed = np.transpose(np_array, axes=dim) - return cast("T", type(array)(transposed.tolist())) - - raise TypeError(f"Unsupported framework type: {type(array)}") diff --git a/decent_bench/utils/interoperability/__init__.py b/decent_bench/utils/interoperability/__init__.py new file mode 100644 index 0000000..1d4940a --- /dev/null +++ b/decent_bench/utils/interoperability/__init__.py @@ -0,0 +1,102 @@ +""" +Utilities for operating on arrays from different deep learning and linear algebra frameworks. + +Mirrors NumPy's functionality for interoperability across frameworks. +""" + +from __future__ import annotations + +from . 
import _ext as ext +from ._decorators import autodecorate_cost_method +from ._functions import ( + argmax, + argmin, + astype, + copy, + eye, + eye_like, + get_item, + max, # noqa: A004 + mean, + min, # noqa: A004 + norm, + ones_like, + rand_like, + randn_like, + reshape, + set_item, + shape, + stack, + sum, # noqa: A004 + to_array, + to_array_like, + to_jax, + to_numpy, + to_tensorflow, + to_torch, + transpose, + zeros, + zeros_like, +) +from ._helpers import framework_device_of_array +from ._operators import ( + absolute, + add, + div, + dot, + matmul, + mul, + negative, + power, + sqrt, + sub, +) + +__all__ = [ # noqa: RUF022 + # From _functions + "argmax", + "argmin", + "astype", + "copy", + "eye", + "eye_like", + "get_item", + "max", + "mean", + "min", + "norm", + "ones_like", + "rand_like", + "randn_like", + "reshape", + "set_item", + "shape", + "stack", + "sum", + "to_array", + "to_array_like", + "to_numpy", + "to_torch", + "to_tensorflow", + "to_jax", + "transpose", + "zeros", + "zeros_like", + # From _operators + "absolute", + "add", + "div", + "dot", + "matmul", + "mul", + "negative", + "power", + "sqrt", + "sub", + # From _helpers + "framework_device_of_array", + # From _decorators + "autodecorate_cost_method", + # Extensions + "ext", +] diff --git a/decent_bench/utils/interoperability/_decorators.py b/decent_bench/utils/interoperability/_decorators.py new file mode 100644 index 0000000..87b597f --- /dev/null +++ b/decent_bench/utils/interoperability/_decorators.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Any, TypeVar, cast + +from decent_bench.utils.array import Array +from decent_bench.utils.logger import LOGGER +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks + +from ._functions import to_array_like, to_jax, to_numpy, to_tensorflow, to_torch +from ._helpers import framework_device_of_array + +if TYPE_CHECKING: 
+ from decent_bench.costs import Cost + +T = TypeVar("T", bound=Callable[..., Any]) +"""A generic callable type variable.""" + + +def _get_converter(framework: SupportedFrameworks) -> Callable[[Array | Any, SupportedDevices], Any]: + if framework == SupportedFrameworks.NUMPY: + return to_numpy + if framework == SupportedFrameworks.TORCH: + return to_torch + if framework == SupportedFrameworks.TENSORFLOW: + return to_tensorflow + if framework == SupportedFrameworks.JAX: + return to_jax + + raise ValueError(f"Unsupported framework: {framework}") + + +def autodecorate_cost_method[T: Callable[..., Any]](superclass_method: T) -> Callable[[Callable[..., Any]], T]: + """ + Decorate Cost methods to automatically convert :class:`~decent_bench.utils.array.Array` args and return types. + + It automatically converts input :class:`~decent_bench.utils.array.Array` arguments + to the cost's framework-specific array type and wraps the output based on the + superclass method's return type annotation. + + Args: + superclass_method: The method from the superclass (e.g., `Cost.function`) that is being overridden. + + Note: + * Only arguments that are instances of :class:`~decent_bench.utils.array.Array` are converted. + Other types are passed through unchanged. + * The first input argument of the decorated function must be ``x``. + This is to determine the target array type for output conversion. Otherwise a :class:`ValueError` is raised. + * Emits a warning if an input array's framework differs from the cost's framework. + This may lead to unexpected behavior or performance issues. + + """ + + def decorator(func: Callable[..., Any]) -> T: + # Determine the expected return type from the superclass method's annotations. 
+ try: + return_type_annotation = superclass_method.__annotations__["return"] + except (AttributeError, KeyError): + return_type_annotation = None + + @wraps(func) + def wrapper(self: Cost, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + converter = _get_converter(self.framework) + + if len(args) > 0: + x_like = args[0] + elif "x" in kwargs: + x_like = kwargs["x"] + else: + raise ValueError("First argument must be 'x' for autodecorate_cost_method to work.") + + new_args = [] + for arg in args: + if isinstance(arg, Array): + framework, _ = framework_device_of_array(arg) + if framework != self.framework: + LOGGER.warning( + f"Converting array from framework {framework} to {self.framework}" + f" in method {func.__name__}. This may lead to unexpected behavior or performance issues." + ) + new_args.append(converter(arg, self.device)) + else: + new_args.append(arg) + + new_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, Array): + framework, _ = framework_device_of_array(value) + if framework != self.framework: + LOGGER.warning( + f"Converting array from framework {framework} to {self.framework}" + f" in method {func.__name__}. This may lead to unexpected behavior or performance issues." + ) + new_kwargs[key] = converter(value, self.device) + else: + new_kwargs[key] = value + + result = func(self, *new_args, **new_kwargs) + + if return_type_annotation is Array: + return to_array_like(result, x_like) + + return result + + # Cast the wrapper to the type of the superclass method. + # This tells mypy that the decorated method is compatible with the superclass. 
+ return cast("T", wrapper) + + return decorator diff --git a/decent_bench/utils/interoperability/_ext.py b/decent_bench/utils/interoperability/_ext.py new file mode 100644 index 0000000..b6138b2 --- /dev/null +++ b/decent_bench/utils/interoperability/_ext.py @@ -0,0 +1,15 @@ +from ._operators import ( + iadd, + idiv, + imul, + ipow, + isub, +) + +__all__ = [ + "iadd", + "idiv", + "imul", + "ipow", + "isub", +] diff --git a/decent_bench/utils/interoperability/_functions.py b/decent_bench/utils/interoperability/_functions.py new file mode 100644 index 0000000..adad569 --- /dev/null +++ b/decent_bench/utils/interoperability/_functions.py @@ -0,0 +1,936 @@ +from __future__ import annotations + +import contextlib +from collections.abc import Sequence +from copy import deepcopy +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +from numpy.typing import NDArray + +from decent_bench.utils.array import Array +from decent_bench.utils.types import ArrayKey, SupportedArrayTypes, SupportedDevices, SupportedFrameworks + +from ._helpers import _device_literal_to_framework_device, _return_array, framework_device_of_array +from ._imports_types import ( + _jax_key, + _jnp_types, + _np_types, + _numpy_generator, + _tf_types, + _torch_types, +) + +jax = None +jnp = None +tf = None +torch = None + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch as _torch + + torch = _torch + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import tensorflow as _tf + + tf = _tf + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import jax.numpy as _jnp + + jnp = _jnp + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import jax as _jax + + jax = _jax + +if TYPE_CHECKING: + from jax import Array as JaxArray + from tensorflow import Tensor as TensorFlowTensor + from torch import Tensor as TorchTensor + + +def to_numpy(array: Array | SupportedArrayTypes, device: SupportedDevices = SupportedDevices.CPU) -> NDArray[Any]: 
# noqa: ARG001 + """ + Convert input array to a NumPy array. + + Args: + array (Array | SupportedArrayTypes): Input Array + device (SupportedDevices): Device of the input array. + + Returns: + NDArray: Converted NumPy array. + + Note: + The `device` parameter is currently not used in this function but is included for API consistency. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return value + if torch and isinstance(value, torch.Tensor): + return cast("np.ndarray", value.cpu().numpy()) # pyright: ignore[reportAttributeAccessIssue] + if tf and isinstance(value, tf.Tensor): + return cast("np.ndarray", value.numpy()) # pyright: ignore[reportAttributeAccessIssue] + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return np.array(value) + return np.array(value) + + +def to_torch(array: Array | SupportedArrayTypes, device: SupportedDevices) -> TorchTensor: + """ + Convert input array to a PyTorch tensor. + + Args: + array (Array | SupportedArrayTypes): Input Array + device (SupportedDevices): Device of the input array. + + Returns: + torch.Tensor: Converted PyTorch tensor. + + Raises: + ImportError: if PyTorch is not installed. 
+ + """ + if not torch: + raise ImportError("PyTorch is not installed.") + + value = array.value if isinstance(array, Array) else array + framework_device = _device_literal_to_framework_device(device, SupportedFrameworks.TORCH) + + if isinstance(value, torch.Tensor): + return cast("TorchTensor", value) + if isinstance(value, np.ndarray | np.generic): + return cast("TorchTensor", torch.from_numpy(value).to(framework_device)) + if tf and isinstance(value, tf.Tensor): + return cast("TorchTensor", torch.tensor(value.cpu(), device=framework_device)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return cast("TorchTensor", torch.from_numpy(np.array(value)).to(framework_device)) + return cast("TorchTensor", torch.tensor(value, device=framework_device)) + + +def to_tensorflow(array: Array | SupportedArrayTypes, device: SupportedDevices) -> TensorFlowTensor: + """ + Convert input array to a TensorFlow tensor. + + Args: + array (Array | SupportedArrayTypes): Input Array + device (SupportedDevices): Device of the input array. + + Returns: + tf.Tensor: Converted TensorFlow tensor. + + Raises: + ImportError: if TensorFlow is not installed. 
+ + """ + if not tf: + raise ImportError("TensorFlow is not installed.") + + value = array.value if isinstance(array, Array) else array + framework_device = _device_literal_to_framework_device(device, SupportedFrameworks.TENSORFLOW) + + if isinstance(value, tf.Tensor): + with tf.device(framework_device): + return cast("TensorFlowTensor", value) + if isinstance(value, np.ndarray | np.generic): + with tf.device(framework_device): + return cast("TensorFlowTensor", tf.convert_to_tensor(value)) + if torch and isinstance(value, torch.Tensor): + with tf.device(framework_device): + return cast("TensorFlowTensor", tf.convert_to_tensor(value.cpu())) # pyright: ignore[reportArgumentType] + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + with tf.device(framework_device): + return cast("TensorFlowTensor", tf.convert_to_tensor(value)) # pyright: ignore[reportArgumentType] + with tf.device(framework_device): + return cast("TensorFlowTensor", tf.convert_to_tensor(value)) # pyright: ignore[reportArgumentType] + + +def to_jax(array: Array | SupportedArrayTypes, device: SupportedDevices) -> JaxArray: + """ + Convert input array to a JAX array. + + Args: + array (Array | SupportedArrayTypes): Input Array + device (SupportedDevices): Device of the input array. + + Returns: + jax.Array: Converted JAX array. + + Raises: + ImportError: if JAX is not installed. 
+ + """ + if not jnp: + raise ImportError("JAX is not installed.") + + value = array.value if isinstance(array, Array) else array + framework_device = _device_literal_to_framework_device(device, SupportedFrameworks.JAX) + + if isinstance(value, jnp.ndarray | jnp.generic): + return cast("JaxArray", value.to_device(framework_device)) + if isinstance(value, np.ndarray | np.generic): + return cast("JaxArray", jnp.array(value, device=framework_device)) + if torch and isinstance(value, torch.Tensor): + return cast("JaxArray", jnp.array(value, device=framework_device)) + if tf and isinstance(value, tf.Tensor): + return cast("JaxArray", jnp.array(value, device=framework_device)) + return cast("JaxArray", jnp.array(value, device=framework_device)) + + +def to_array( + array: Array | SupportedArrayTypes, + framework: SupportedFrameworks, + device: SupportedDevices, +) -> Array: + """ + Convert an array to the specified framework type. + + See :func:`decent_bench.utils.interoperability.to_array_like` if you want to convert an array to match + the framework and device of another array. + + Args: + array (Array | SupportedArrayTypes): Input array. + framework (SupportedFrameworks): Target framework type (e.g., "torch", "tf"). + device (SupportedDevices): Target device ("cpu" or "gpu"). + + Returns: + Array: Converted array in the specified framework type. + + Raises: + TypeError: if the framework type of `framework` is unsupported. 
+ + """ + if framework == SupportedFrameworks.NUMPY: + return _return_array(to_numpy(array, device)) + if torch and framework == SupportedFrameworks.TORCH: + return _return_array(to_torch(array, device)) + if tf and framework == SupportedFrameworks.TENSORFLOW: + return _return_array(to_tensorflow(array, device)) + if jnp and framework == SupportedFrameworks.JAX: + return _return_array(to_jax(array, device)) + + raise TypeError(f"Unsupported framework type: {framework}") + + +def to_array_like(array: Array | SupportedArrayTypes, like: Array) -> Array: + """ + Convert an array to the framework/device of `like`. + + Args: + array (Array | SupportedArrayTypes): Input array. + like (Array): Array whose framework and device to match. + + Returns: + Array: Converted array in the specified framework type. + + """ + framework, device = framework_device_of_array(like) + + return to_array(array, framework, device) + + +def sum( # noqa: A001 + array: Array, + dim: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """ + Sum elements of an array. + + Args: + array (Array): Input array. + dim (int | tuple[int, ...] | None): Dimension or dimensions along which to sum. + If None, sums over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Summed value in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.sum(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.sum(value, dim=dim, keepdim=keepdims)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.reduce_sum(value, axis=dim, keepdims=keepdims)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.sum(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def mean( + array: Array, + dim: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """ + Compute mean of array elements. + + Args: + array (Array): Input array. + dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute the mean. + If None, computes mean of flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Mean value in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.mean(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.mean(value, dim=dim, keepdim=keepdims)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.reduce_mean(value, axis=dim, keepdims=keepdims)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.mean(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def min( # noqa: A001 + array: Array, + dim: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """ + Compute minimum of array elements. + + Args: + array (Array): Input array. 
+ dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute minimum. + If None, finds minimum over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Minimum value in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.min(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.amin(value, dim=dim, keepdim=keepdims)) # pyright: ignore[reportArgumentType] + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.reduce_min(value, axis=dim, keepdims=keepdims)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.min(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def max( # noqa: A001 + array: Array, + dim: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """ + Compute maximum of array elements. + + Args: + array (Array): Input array. + dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute maximum. + If None, finds maximum over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Maximum value in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.max(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.amax(value, dim=dim, keepdim=keepdims)) # pyright: ignore[reportArgumentType] + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.reduce_max(value, axis=dim, keepdims=keepdims)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.max(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def argmax(array: Array, dim: int | None = None, keepdims: bool = False) -> Array: + """ + Compute index of maximum value. + + Args: + array (Array): Input array. + dim (int | None): Dimension along which to find maximum. If None, finds maximum over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Indices of maximum values in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.argmax(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.argmax(value, dim=dim, keepdim=keepdims)) + if tf and isinstance(value, tf.Tensor): + if dim is None: + # TensorFlow's argmax does not support dim=None directly + dims = value.ndim if value.ndim is not None else 0 + reshaped_array = tf.reshape(value, [-1]) + amax = tf.math.argmax(reshaped_array, axis=0) + ret = _return_array(amax) if not keepdims else _return_array(tf.reshape(amax, [1] * dims)) + else: + ret = ( + _return_array(tf.math.argmax(value, axis=dim)) + if not keepdims + else _return_array(tf.expand_dims(tf.math.argmax(value, axis=dim), axis=dim)) + ) + return ret + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.argmax(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def argmin(array: Array, dim: int | None = None, keepdims: bool = False) -> Array: + """ + Compute index of minimum value. + + Args: + array (Array): Input array. + dim (int | None): Dimension along which to find minimum. If None, finds minimum over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: Indices of minimum values in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.argmin(value, axis=dim, keepdims=keepdims)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.argmin(value, dim=dim, keepdim=keepdims)) + if tf and isinstance(value, tf.Tensor): + ret = None + if dim is None: + # TensorFlow's argmin does not support dim=None directly + dims = value.ndim if value.ndim is not None else 0 + tf_array = tf.reshape(value, [-1]) + amin = tf.math.argmin(tf_array, axis=0) + ret = _return_array(amin) if not keepdims else _return_array(tf.reshape(amin, [1] * dims)) + else: + ret = ( + _return_array(tf.math.argmin(value, axis=dim)) + if not keepdims + else _return_array(tf.expand_dims(tf.math.argmin(value, axis=dim), axis=dim)) + ) + return ret + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.argmin(value, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def copy(array: Array) -> Array: + """ + Create a copy of the input array. + + Args: + array (Array): Input array. + + Returns: + Array: A copy of the input array in the same framework type. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.copy(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.clone(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.identity(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.array(value, copy=True)) + return deepcopy(array) + + +def stack(arrays: Sequence[Array], dim: int = 0) -> Array: + """ + Stack a sequence of arrays along a new dimension. + + Args: + arrays (Sequence[Array]): Sequence of input arrays. + or nested containers (list, tuple). + dim (int): Dimension along which to stack the arrays. 
+ + Returns: + Array: Stacked array in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported. + + """ + arrs = [arr.value for arr in arrays] if isinstance(arrays[0], Array) else arrays + + if isinstance(arrs[0], np.ndarray | np.generic): + return _return_array(np.stack(arrs, axis=dim)) # pyright: ignore[reportArgumentType, reportCallIssue] + if torch and isinstance(arrs[0], torch.Tensor): + return _return_array(torch.stack(arrs, dim=dim)) # pyright: ignore[reportArgumentType] + if tf and isinstance(arrs[0], tf.Tensor): + return _return_array(tf.stack(arrs, axis=dim)) + if jnp and isinstance(arrs[0], jnp.ndarray | jnp.generic): + return _return_array(jnp.stack(arrs, axis=dim)) # pyright: ignore[reportArgumentType] + + raise TypeError(f"Unsupported framework type or mixed types: {[type(arr) for arr in arrs]}") + + +def reshape(array: Array, shape: tuple[int, ...]) -> Array: + """ + Reshape an array to the specified shape. + + Args: + array (Array): Input array. + shape (tuple[int, ...]): Desired shape for the output array. + + Returns: + Array: Reshaped array in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.reshape(value, shape)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.reshape(value, shape)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.reshape(value, shape)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.reshape(value, shape)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def zeros_like(array: Array) -> Array: + """ + Create an array of zeros with the same shape and type as the input. + + Args: + array (Array): Input array. 
+ + Returns: + Array: Array of zeros in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.zeros_like(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.zeros_like(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.zeros_like(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.zeros_like(value)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def ones_like(array: Array) -> Array: + """ + Create an array of ones with the same shape and type as the input. + + Args: + array (Array): Input array. + + Returns: + Array: Array of ones in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.ones_like(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.ones_like(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.ones_like(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.ones_like(value)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def rand_like(array: Array, low: float = 0.0, high: float = 1.0) -> Array: + """ + Create an array of random values with the same shape and type as the input. + + Values are drawn uniformly from [low, high). + + Args: + array (Array): Input array. + low (float): Lower bound of the uniform distribution. + high (float): Upper bound of the uniform distribution. + + Returns: + Array: Array of random values in the same framework type as the input. 
+ + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + random_array = _numpy_generator().uniform(low=low, high=high, size=value.shape) + return _return_array(random_array) + if torch and isinstance(value, torch.Tensor): + return _return_array((high - low) * torch.rand_like(value) + low) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.random.uniform(tf.shape(value), dtype=value.dtype, minval=low, maxval=high)) + if jnp and jax and isinstance(value, jnp.ndarray | jnp.generic): + global _jax_key + _jax_key, sub_key = jax.random.split(_jax_key) # pyright: ignore[reportArgumentType] + return _return_array(jax.random.uniform(sub_key, shape=value.shape, dtype=value.dtype, minval=low, maxval=high)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def randn_like(array: Array, mean: float = 0.0, std: float = 1.0) -> Array: + """ + Create an array of random values with the same shape and type as the input. + + Values are drawn from a normal distribution with mean `mean` and standard deviation `std`. + + Args: + array (Array): Input array. + mean (float): Mean of the normal distribution. + std (float): Standard deviation of the normal distribution. + + Returns: + Array: Array of random values in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + random_array = _numpy_generator().normal(loc=mean, scale=std, size=value.shape) + return _return_array(random_array) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.normal(mean=mean, std=std, size=value.shape, dtype=value.dtype, device=value.device)) + if tf and isinstance(value, tf.Tensor): + shape = tf.shape(value) + return _return_array(tf.random.normal(shape=shape, mean=mean, stddev=std, dtype=value.dtype)) + if jnp and jax and isinstance(value, jnp.ndarray | jnp.generic): + global _jax_key + _jax_key, sub_key = jax.random.split(_jax_key) # pyright: ignore[reportArgumentType] + return _return_array(mean + std * jax.random.normal(sub_key, shape=value.shape, dtype=value.dtype)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def eye_like(array: Array) -> Array: + """ + Create an identity matrix with the same shape as the input. + + Args: + array (Array): Input array. + + Returns: + Array: Identity matrix in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.eye(*value.shape[-2:], dtype=value.dtype)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.eye(*value.shape[-2:], dtype=value.dtype, device=value.device)) + if tf and isinstance(value, tf.Tensor): + shape = tf.shape(value) + return _return_array(tf.eye(*shape[-2:], dtype=value.dtype)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.eye(*value.shape[-2:], dtype=value.dtype, device=value.device)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def eye(n: int, framework: SupportedFrameworks, device: SupportedDevices) -> Array: + """ + Create an identity matrix of size n x n in the specified framework. + + Args: + n (int): Size of the identity matrix. + framework (SupportedFrameworks): Target framework type (e.g., "torch", "tf"). + device (SupportedDevices): Target device ("cpu" or "gpu"). + + Returns: + Array: Identity matrix in the specified framework type. + + Raises: + TypeError: if the framework type of `framework` is unsupported. + + """ + if framework == SupportedFrameworks.NUMPY: + return _return_array(np.eye(n)) + + framework_device = _device_literal_to_framework_device(device, framework) + + if torch and framework == SupportedFrameworks.TORCH: + return _return_array(torch.eye(n, device=framework_device)) + if tf and framework == SupportedFrameworks.TENSORFLOW: + with tf.device(framework_device): + return _return_array(tf.eye(n)) + if jnp and framework == SupportedFrameworks.JAX: + return _return_array(jnp.eye(n, device=framework_device)) + + raise TypeError(f"Unsupported framework type: {framework}") + + +def transpose(array: Array, dim: tuple[int, ...] | None = None) -> Array: + """ + Transpose an array. + + Args: + array (Array): Input array. + dim (tuple[int, ...] | None): Desired dim order. If None, reverses the dimensions. 
+ + Returns: + Array: Transposed array in the same framework type as the input. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.transpose(value, axes=dim)) + if torch and isinstance(value, torch.Tensor): + # Handle None case for PyTorch + return ( + _return_array(torch.permute(value, dims=dim)) + if dim + else _return_array(torch.permute(value, dims=list(reversed(range(value.ndim))))) + ) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.transpose(value, perm=dim)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.transpose(value, axes=dim)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def shape(array: Array) -> tuple[int, ...]: + """ + Get the shape of an array. + + Args: + array (Array): Input array. + + Returns: + tuple[int, ...]: Shape of the input array. + + Raises: + TypeError: if the framework type of `array` is unsupported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return value.shape + if torch and isinstance(value, torch.Tensor): + return tuple(value.shape) + if tf and isinstance(value, tf.Tensor): + tf_shape = tuple(value.shape) + return cast("tuple[int, ...]", tf_shape) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return cast("tuple[int, ...]", value.shape) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def zeros( + shape: tuple[int, ...], + framework: SupportedFrameworks, + device: SupportedDevices, + dtype: Any | None = None, # noqa: ANN401 +) -> Array: + """ + Create a tensor of zeros. + + Args: + shape (tuple): The shape of the tensor. + framework (SupportedFrameworks): The framework to use ("numpy", "torch", "tensorflow", "jax"). 
+ device (SupportedDevices): The device to place the tensor on. + dtype (Any | None): The data type of the tensor. Defaults to None. + + Returns: + Array: A tensor of zeros. + + """ + x = np.zeros(shape, dtype=dtype) + + return to_array(x, framework=framework, device=device) + + +def set_item( + array: Array | SupportedArrayTypes, + key: ArrayKey, + value: Array | SupportedArrayTypes, +) -> None: + """ + Set the item at the specified index of the array to the given value. + + Args: + array (Array | SupportedArrayTypes): The tensor. + key (ArrayKey): The key or index to set. + value (Array | SupportedArrayTypes): The value to set. + + Raises: + TypeError: If the type is not supported. + NotImplementedError: If the operation is not supported due to immutability. + + """ + array_value = array.value if isinstance(array, Array) else array + value_value = value.value if isinstance(value, Array) else value + + if isinstance(array_value, np.ndarray | np.generic): + array_value[key] = value_value + return + if torch and isinstance(array_value, torch.Tensor) and isinstance(value_value, _torch_types): + array_value[key] = value_value + return + if tf and isinstance(array_value, tf.Tensor) and isinstance(value_value, _tf_types): + raise NotImplementedError("Setting items in TensorFlow tensors is not supported due to immutability.") + if jnp and isinstance(array_value, jnp.ndarray | jnp.generic) and isinstance(value_value, _jnp_types): + raise NotImplementedError("Setting items in JAX arrays is not supported due to immutability.") + + raise TypeError(f"Unsupported type: {type(array_value)} with value: {type(value_value)}") + + +def get_item(array: Array, key: ArrayKey) -> Array: + """ + Get the item at the specified index of the array. + + Args: + array (Array): The tensor. + key (ArrayKey): The key or index to get. + + Returns: + Array: The item at the specified index. 
+ + """ + value = array.value if isinstance(array, Array) else array + + return _return_array(value[key]) # type: ignore[index] + + +def astype(array: Array, dtype: type[float | int | bool]) -> float | int | bool: + """ + Cast a single-element array to a Python scalar of the specified type. + + Args: + array (Array): The tensor. + dtype (float | int | bool): The target data type. + + Returns: + float | int | bool: The casted scalar value. + + Raises: + TypeError: If the type is not supported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, _np_types): + return dtype(value.item() if hasattr(value, "item") else value) # pyright: ignore[reportAttributeAccessIssue] + if torch and isinstance(value, torch.Tensor): + return dtype(value.item()) + if tf and isinstance(value, tf.Tensor): + return dtype(to_numpy(value).item()) + if jnp and isinstance(value, _jnp_types): + return dtype(value.item()) + + raise TypeError(f"Unsupported type: {type(value)}") + + +def norm( + array: Array, + p: float = 2, + dim: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """ + Compute the norm of an array. + + Args: + array (Array): The tensor. + p (float): The order of the norm. + dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute the norm. + If None, computes norm over flattened array. + keepdims (bool): If True, retains reduced dimensions with length 1. + + Returns: + Array: The norm of the tensor. + + Raises: + TypeError: If the type is not supported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(cast("SupportedArrayTypes", np.linalg.norm(value, ord=p, axis=dim, keepdims=keepdims))) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.linalg.norm(value, ord=p, dim=dim, keepdim=keepdims)) + if tf and isinstance(value, tf.Tensor): + if dim is None and value.ndim == 2: + dim = (-2, -1) + return _return_array(tf.norm(value, ord=p, axis=dim, keepdims=keepdims)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.linalg.norm(value, ord=p, axis=dim, keepdims=keepdims)) + + raise TypeError(f"Unsupported type: {type(value)}") diff --git a/decent_bench/utils/interoperability/_helpers.py b/decent_bench/utils/interoperability/_helpers.py new file mode 100644 index 0000000..8c1c657 --- /dev/null +++ b/decent_bench/utils/interoperability/_helpers.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from decent_bench.utils.array import Array +from decent_bench.utils.types import SupportedArrayTypes, SupportedDevices, SupportedFrameworks + +from ._imports_types import _jnp_types, _np_types, _tf_types, _torch_types, jax, jnp, tf, torch + + +def _device_literal_to_framework_device(device: SupportedDevices, framework: SupportedFrameworks) -> Any: # noqa: ANN401 + """ + Convert SupportedDevices literal to framework-specific device representation. + + Args: + device (SupportedDevices): Device literal ("cpu" or "gpu"). + framework (SupportedFrameworks): Framework literal ("numpy", "torch", "tensorflow", "jax"). + + Returns: + Any: Framework-specific device representation. + + Raises: + ValueError: If the framework is unsupported. 
+ + """ + if framework == SupportedFrameworks.NUMPY: + return device # NumPy does not have explicit device management + if torch and framework == SupportedFrameworks.TORCH: + return "cuda" if device == SupportedDevices.GPU else "cpu" + if tf and framework == SupportedFrameworks.TENSORFLOW: + return f"/{device.value}:0" + if jax and framework == SupportedFrameworks.JAX: + if device == SupportedDevices.CPU: + return jax.devices("cpu")[0] + return jax.devices("gpu")[0] + raise ValueError(f"Unsupported framework: {framework}") + + +def _return_array(array: SupportedArrayTypes) -> Array: + """ + Wrap a framework-native array in an `Array` wrapper. + + This helper standardizes return types across interoperability functions, + returning the same framework-native object at runtime, while providing a + typed `Array` during static type checking. + + Args: + array (SupportedArrayTypes): Input array (NumPy, torch, tf, jax). + + Returns: + Array: Wrapped array (type-only during static analysis; at runtime + this returns the original framework-native value). + + """ + if not TYPE_CHECKING: + return array + + return Array(array) + + +def framework_device_of_array(array: Array) -> tuple[SupportedFrameworks, SupportedDevices]: + """ + Determine the framework and device of the given Array. + + Args: + array (Array): Input array. + + Returns: + tuple[SupportedFrameworks, SupportedDevices]: Framework and device of the array. + + Raises: + TypeError: if the framework type of `array` is unsupported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, _np_types): + return SupportedFrameworks.NUMPY, SupportedDevices.CPU + if torch and isinstance(value, _torch_types): + device_type = SupportedDevices.GPU if value.device.type == "cuda" else SupportedDevices.CPU # type: ignore[union-attr] + return SupportedFrameworks.TORCH, device_type + if tf and isinstance(value, _tf_types): + device_str = value.device.lower() # type: ignore[union-attr] + device_type = SupportedDevices.GPU if "gpu" in device_str or "cuda" in device_str else SupportedDevices.CPU + return SupportedFrameworks.TENSORFLOW, device_type + if jnp and isinstance(value, _jnp_types): + backend = value.device.platform # type: ignore[union-attr] + device_type = SupportedDevices.GPU if backend == "gpu" else SupportedDevices.CPU + return SupportedFrameworks.JAX, device_type + + raise TypeError(f"Unsupported framework type: {type(value)}") diff --git a/decent_bench/utils/interoperability/_imports_types.py b/decent_bench/utils/interoperability/_imports_types.py new file mode 100644 index 0000000..94358e6 --- /dev/null +++ b/decent_bench/utils/interoperability/_imports_types.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import contextlib +import random +from functools import cache +from types import ModuleType + +import numpy as np + +jax: ModuleType | None = None +jnp: ModuleType | None = None +tf: ModuleType | None = None +torch: ModuleType | None = None + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch as _torch + + torch = _torch + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import tensorflow as _tf + + tf = _tf + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import jax.numpy as _jnp + + jnp = _jnp + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import jax as _jax + + jax = _jax + + +_np_types = (np.ndarray, np.generic, float, int) +_torch_types = (torch.Tensor, float, int) if 
torch else (float,) +_tf_types = (tf.Tensor, float, int) if tf else (float,) +_jnp_types = (jnp.ndarray, jnp.generic, float, int) if jnp else (float,) + +_jax_key = jax.random.key(random.randint(0, 2**32 - 1)) if jax else None + + +@cache +def _numpy_generator() -> np.random.Generator: + """Get a NumPy random number generator instance.""" + return np.random.default_rng() diff --git a/decent_bench/utils/interoperability/_operators.py b/decent_bench/utils/interoperability/_operators.py new file mode 100644 index 0000000..5167b94 --- /dev/null +++ b/decent_bench/utils/interoperability/_operators.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +from typing import cast + +import numpy as np + +from decent_bench.utils.array import Array +from decent_bench.utils.types import SupportedArrayTypes + +from ._helpers import _return_array +from ._imports_types import _jnp_types, _np_types, _tf_types, _torch_types, jnp, tf, torch + + +def add(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Element-wise addition of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise addition in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, _np_types): + return _return_array(value1 + value2) + if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types): + return _return_array(torch.add(value1, value2)) + if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types): + return _return_array(tf.add(value1, value2)) + if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types): + return _return_array(jnp.add(value1, value2)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def iadd[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T: + """ + Element-wise in-place addition of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise in-place addition in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + value1 += value2 + return cast("T", _return_array(value1)) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types): + value1 += value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types): + value1 += value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types): + value1 += value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def sub(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Element-wise subtraction of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise subtraction in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, _np_types): + return _return_array(value1 - value2) + if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types): + return _return_array(torch.sub(value1, value2)) + if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types): + return _return_array(tf.subtract(value1, value2)) + if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types): + return _return_array(jnp.subtract(value1, value2)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def isub[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T: + """ + Element-wise in-place subtraction of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise in-place subtraction in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + value1 -= value2 + return cast("T", _return_array(value1)) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types): + value1 -= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types): + value1 -= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types): + value1 -= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def mul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Element-wise multiplication of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise multiplication in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, _np_types): + return _return_array(value1 * value2) # pyright: ignore[reportOperatorIssue] + if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types): + return _return_array(torch.mul(value1, value2)) + if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types): + return _return_array(tf.multiply(value1, value2)) + if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types): + return _return_array(jnp.multiply(value1, value2)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def imul[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T: + """ + Element-wise in-place multiplication of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise in-place multiplication in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + value1 *= value2 + return cast("T", _return_array(value1)) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types): + value1 *= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types): + value1 *= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types): + value1 *= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def div(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Element-wise division of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise division in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, _np_types): + return _return_array(value1 / value2) + if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types): + return _return_array(torch.div(value1, value2)) + if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types): + return _return_array(tf.divide(value1, value2)) + if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types): + return _return_array(jnp.divide(value1, value2)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def idiv[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T: + """ + Element-wise in-place division of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of element-wise in-place division in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + value1 /= value2 + return cast("T", _return_array(value1)) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types): + value1 /= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types): + value1 /= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types): + value1 /= value2 # pyright: ignore[reportOperatorIssue] + return cast("T", _return_array(value1)) + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def matmul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Matrix multiplication of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of matrix multiplication in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + return _return_array(value1 @ value2) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor): + return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue] + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, tf.Tensor): + return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue] + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, jnp.ndarray | jnp.generic): + return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue] + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def dot(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array: + """ + Dot product of two arrays. + + Args: + array1 (Array | SupportedArrayTypes): First input array. + array2 (Array | SupportedArrayTypes): Second input array. + + Returns: + Array: Result of the dot product in the same framework type as the inputs. + + Raises: + TypeError: if the framework type of the input arrays is unsupported + or if the input arrays are not of the same framework type. 
+ + """ + value1 = array1.value if isinstance(array1, Array) else array1 + value2 = array2.value if isinstance(array2, Array) else array2 + + if isinstance(value1, np.ndarray | np.generic): + return _return_array(value1.dot(value2)) + if torch and isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor): + return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] + if tf and isinstance(value1, tf.Tensor) and isinstance(value2, tf.Tensor): + return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] + if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, jnp.ndarray | jnp.generic): + return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] + + raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}") + + +def power(array: Array | SupportedArrayTypes, p: float) -> Array: + """ + Raise array to p power. + + Args: + array (Array | SupportedArrayTypes): The tensor. + p (float): The power. + + Returns: + Array: The result of the operation. + + Raises: + TypeError: If the type is not supported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.power(value, p)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.pow(value, p)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.pow(value, p)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.power(value, p)) + + raise TypeError(f"Unsupported type: {type(value)}") + + +def ipow[T: Array](array: T, p: float) -> T: + """ + Element-wise in-place power of an array. + + Args: + array (Array | SupportedArrayTypes): Input array. + p (float): The power. + + Returns: + Array: Result of element-wise in-place power in the same framework type as the inputs. 
+ + Raises: + TypeError: if the framework type of the input arrays is unsupported + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + value **= p + return cast("T", _return_array(value)) + if torch and isinstance(value, torch.Tensor): + value **= p + return cast("T", _return_array(value)) + if tf and isinstance(value, tf.Tensor): + value **= p + return cast("T", _return_array(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + value **= p + return cast("T", _return_array(value)) + + raise TypeError(f"Unsupported framework type: {type(value)}") + + +def negative(array: Array | SupportedArrayTypes) -> Array: + """ + Negate array. + + Args: + array (Array | SupportedArrayTypes): The tensor. + + Returns: + Array: The negated tensor. + + Raises: + TypeError: If the type is not supported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.negative(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.neg(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.negative(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.negative(value)) + + raise TypeError(f"Unsupported type: {type(value)}") + + +def absolute(array: Array | SupportedArrayTypes) -> Array: + """ + Return the absolute value of a tensor. + + Args: + array (Array | SupportedArrayTypes): The tensor. + + Returns: + Array: The absolute value tensor. + + Raises: + TypeError: If the type is not supported. 
+ + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.abs(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.abs(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.abs(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.abs(value)) + + raise TypeError(f"Unsupported type: {type(value)}") + + +def sqrt(array: Array | SupportedArrayTypes) -> Array: + """ + Return the square root of a tensor. + + Args: + array (Array | SupportedArrayTypes): The tensor. + + Returns: + Array: The square root tensor. + + Raises: + TypeError: If the type is not supported. + + """ + value = array.value if isinstance(array, Array) else array + + if isinstance(value, np.ndarray | np.generic): + return _return_array(np.sqrt(value)) + if torch and isinstance(value, torch.Tensor): + return _return_array(torch.sqrt(value)) + if tf and isinstance(value, tf.Tensor): + return _return_array(tf.sqrt(value)) + if jnp and isinstance(value, jnp.ndarray | jnp.generic): + return _return_array(jnp.sqrt(value)) + + raise TypeError(f"Unsupported type: {type(value)}") diff --git a/decent_bench/utils/types.py b/decent_bench/utils/types.py new file mode 100644 index 0000000..6a8b6fd --- /dev/null +++ b/decent_bench/utils/types.py @@ -0,0 +1,46 @@ +"""Type definitions for optimization variables.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, SupportsIndex, TypeAlias, Union + +if TYPE_CHECKING: + import jax + import numpy + import tensorflow as tf + import torch + +ArrayLike: TypeAlias = Union["numpy.ndarray", "torch.Tensor", "tf.Tensor", "jax.Array"] # noqa: UP040 +""" +Type alias for array-like types supported in decent-bench, including NumPy arrays, +PyTorch tensors, TensorFlow tensors, and JAX arrays. 
+""" + +SupportedArrayTypes: TypeAlias = ArrayLike | float | int # noqa: UP040 +""" +Type alias for supported types for optimization variables in decent-bench, +including array-like types and scalars. +""" + +ArrayKey: TypeAlias = SupportsIndex | slice | tuple[SupportsIndex | slice, ...] # noqa: UP040 +""" +Type alias for valid keys used to index into supported array types. +Includes single indices, tuples of indices, slices, and tuples of slices. +""" + + +class SupportedFrameworks(Enum): + """Enum for supported frameworks in decent-bench.""" + + NUMPY = "numpy" + TORCH = "torch" + TENSORFLOW = "tensorflow" + JAX = "jax" + + +class SupportedDevices(Enum): + """Enum for supported devices in decent-bench.""" + + CPU = "cpu" + GPU = "gpu" diff --git a/docs/source/api/decent_bench.costs.rst b/docs/source/api/decent_bench.costs.rst index 76501d3..8326061 100644 --- a/docs/source/api/decent_bench.costs.rst +++ b/docs/source/api/decent_bench.costs.rst @@ -4,4 +4,6 @@ decent\_bench.costs .. automodule:: decent_bench.costs :members: :show-inheritance: - :undoc-members: \ No newline at end of file + :undoc-members: + :special-members: + __add__, \ No newline at end of file diff --git a/docs/source/api/decent_bench.utils.array.rst b/docs/source/api/decent_bench.utils.array.rst new file mode 100644 index 0000000..1d35a12 --- /dev/null +++ b/docs/source/api/decent_bench.utils.array.rst @@ -0,0 +1,7 @@ +decent\_bench.utils.array +========================= + +.. automodule:: decent_bench.utils.array + :members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/docs/source/api/decent_bench.utils.interoperability.rst b/docs/source/api/decent_bench.utils.interoperability.rst index 06fc692..603dd8c 100644 --- a/docs/source/api/decent_bench.utils.interoperability.rst +++ b/docs/source/api/decent_bench.utils.interoperability.rst @@ -4,4 +4,6 @@ decent\_bench.utils.interoperability .. 
automodule:: decent_bench.utils.interoperability :members: :show-inheritance: - :undoc-members: \ No newline at end of file + :undoc-members: + :exclude-members: + ext \ No newline at end of file diff --git a/docs/source/api/decent_bench.utils.rst b/docs/source/api/decent_bench.utils.rst index 28e891a..0df8288 100644 --- a/docs/source/api/decent_bench.utils.rst +++ b/docs/source/api/decent_bench.utils.rst @@ -5,9 +5,11 @@ decent\_bench.utils .. toctree:: :maxdepth: 2 + decent_bench.utils.array + decent_bench.utils.interoperability decent_bench.utils.logger decent_bench.utils.progress_bar - decent_bench.utils.interoperability + decent_bench.utils.types .. automodule:: decent_bench.utils :members: diff --git a/docs/source/api/decent_bench.utils.types.rst b/docs/source/api/decent_bench.utils.types.rst new file mode 100644 index 0000000..26e37f2 --- /dev/null +++ b/docs/source/api/decent_bench.utils.types.rst @@ -0,0 +1,7 @@ +decent\_bench.utils.types +========================= + +.. automodule:: decent_bench.utils.types + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/source/api/snippets/proximal_operator.rst b/docs/source/api/snippets/proximal_operator.rst index 5728356..52d95bd 100644 --- a/docs/source/api/snippets/proximal_operator.rst +++ b/docs/source/api/snippets/proximal_operator.rst @@ -1,5 +1,5 @@ .. math:: - \operatorname{prox}_{\rho f}(\mathbf{y}) - = \arg\min_{\mathbf{x}} \left\{ f(\mathbf{x}) + \frac{1}{2\rho} \| \mathbf{x} - \mathbf{y} \|^2 \right\} + \operatorname{prox}_{\rho f}(\mathbf{x}) + = \arg\min_{\mathbf{y}} \left\{ f(\mathbf{y}) + \frac{1}{2\rho} \| \mathbf{y} - \mathbf{x} \|^2 \right\} where :math:`\rho > 0` is the penalty and :math:`f` the cost function. 
diff --git a/docs/source/conf.py b/docs/source/conf.py index a06bf83..06c471d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,7 +20,6 @@ sys.path.insert(0, os.path.abspath("../..")) - extensions = [ "sphinx.ext.autodoc", # Expand rst automodule directives generated by `sphinx-apidoc` "sphinx.ext.intersphinx", # Link to types from external packages @@ -28,26 +27,33 @@ "sphinx.ext.viewcode", # View source code ] -autodoc_default_options = { - "special-members": "__add__", -} +autodoc_default_options = {} autodoc_member_order = "bysource" autodoc_preserve_defaults = True -autodoc_type_aliases = { - "ArrayLike": "numpy.typing.ArrayLike", -} +autodoc_type_aliases = {} nitpicky = True nitpick_ignore = [ ("py:class", "numpy.float64"), + ("py:class", "float64"), ("py:class", "numpy._typing._array_like._SupportsArray"), ("py:class", "numpy._typing._nested_sequence._NestedSequence"), + ("py:class", "T"), + ("py:class", "TorchTensor"), + ("py:class", "TensorFlowTensor"), + ("py:class", "JaxArray"), ] intersphinx_mapping = { "networkx": ("https://networkx.org/documentation/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "python": ("https://docs.python.org/3", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "tensorflow": ( + "https://www.tensorflow.org/api_docs/python", + "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", + ), + "jax": ("https://jax.readthedocs.io/en/latest/", None), } diff --git a/docs/source/developer.rst b/docs/source/developer.rst index cfe3f03..2c1d871 100644 --- a/docs/source/developer.rst +++ b/docs/source/developer.rst @@ -36,7 +36,9 @@ To make sure all GitHub status checks pass, simply run :code:`tox`. You can also tox -e ruff # find formatting and style issues tox -e sphinx # rebuild documentation -Note: Running :code:`tox` commands can take several minutes and may require admin privileges. 
+Note: Running :code:`tox` commands can take several minutes and may require admin privileges. +If you have mypy addon installed in your IDE, you can use it to get instant feedback on typing issues while coding. +If mypy fails with ``KeyError: 'setter_type'``, delete the ``.mypy_cache`` folder in the project root. Tools can also be used directly (instead of via tox) after activating the dev environment. Useful examples include: diff --git a/docs/source/user.rst b/docs/source/user.rst index 3129dd8..0ee78a8 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -33,6 +33,16 @@ Benchmark algorithms on a regression problem without any communication constrain ) +Benchmark executions will have outputs like these: + +.. list-table:: + + * - .. image:: _static/table.png + :align: center + - .. image:: _static/plot.png + :align: center + + Execution settings ------------------ Configure settings for metrics, trials, statistical confidence level, logging, and multiprocessing. @@ -148,11 +158,16 @@ Create a custom benchmark problem using existing resources. from decent_bench.datasets import SyntheticClassificationData from decent_bench.distributed_algorithms import ADMM, DGD, ED from decent_bench.schemes import GaussianNoise, Quantization, UniformActivationRate, UniformDropRate + from decent_bench.utils.types import SupportedFrameworks n_agents = 100 dataset = SyntheticClassificationData( - n_classes=2, n_partitions=n_agents, n_samples_per_partition=10, n_features=3 + n_classes=2, + n_partitions=n_agents, + n_samples_per_partition=10, + n_features=3, + framework=SupportedFrameworks.NUMPY, ) costs = [LogisticRegressionCost(*p) for p in dataset.training_partitions()] @@ -237,6 +252,21 @@ corresponding abstracts. benchmark_problem=problem, ) +Interoperability requirement +---------------------------- +Decent-Bench is designed to interoperate with multiple array/tensor frameworks (NumPy, PyTorch, JAX, etc.). 
To keep +algorithms framework-agnostic, always use the interoperability layer :class:`~decent_bench.utils.interoperability`, aliased as +`iop`, and the :class:`~decent_bench.utils.array.Array` wrapper when creating, manipulating, and exchanging values: + +- Use :class:`decent_bench.utils.interoperability.zeros` instead of framework-specific constructors (e.g., `np.zeros`, `torch.zeros`). + Other examples are :meth:`~decent_bench.utils.interoperability.ones_like`, :meth:`~decent_bench.utils.interoperability.rand_like`, :meth:`~decent_bench.utils.interoperability.randn_like`, etc. + See :mod:`~decent_bench.utils.interoperability` for a full list of available methods and :mod:`~decent_bench.distributed_algorithms` for examples of usage. +- Avoid calling any framework-specific functions directly within your algorithm. + Let the :class:`~decent_bench.costs.Cost` implementations handle framework-specific details for + :func:`~decent_bench.costs.Cost.function`, :func:`~decent_bench.costs.Cost.gradient`, :func:`~decent_bench.costs.Cost.hessian`, and :func:`~decent_bench.costs.Cost.proximal`. +- When you need to create a new array/tensor, use the interoperability layer to ensure compatibility with the agent's cost function framework and device. + If a method to create your specific array/tensor is not available, see the implementation of :attr:`~decent_bench.networks.P2PNetwork.weights` as an example. + Algorithms ---------- @@ -244,11 +274,12 @@ Create a new algorithm to benchmark against existing ones. **Note**: In order for metrics to work, use :attr:`Agent.x ` to update the local primal variable. Similarly, in order for the benchmark problem's communication schemes to be applied, use the -:attr:`~decent_bench.networks.P2PNetwork` object to retrieve agents and to send and receive messages. +:attr:`~decent_bench.networks.P2PNetwork` object to retrieve agents and to send and receive messages. 
+Be sure to use :meth:`~decent_bench.networks.P2PNetwork.active_agents` during algorithm runtime, so that asynchrony is properly handled. .. code-block:: python - import numpy as np + import decent_bench.utils.interoperability as iop from decent_bench import benchmark, benchmark_problem from decent_bench.costs import LinearRegressionCost @@ -256,17 +287,15 @@ variable. Similarly, in order for the benchmark problem's communication schemes from decent_bench.networks import P2PNetwork class MyNewAlgorithm(Algorithm): + iterations: int + step_size: float name: str = "MNA" - def __init__(self, iterations: int, step_size: float): - self.iterations = iterations - self.step_size = step_size - def run(self, network: P2PNetwork) -> None: - # Initialize agents + # Initialize agents with Array values using the interoperability layer for agent in network.agents(): - x0 = np.zeros(agent.cost.shape) - y0 = np.zeros(agent.cost.shape) + x0 = iop.zeros(shape=agent.cost.shape, framework=agent.cost.framework, device=agent.cost.device) + y0 = iop.zeros(shape=agent.cost.shape, framework=agent.cost.framework, device=agent.cost.device) neighbors = network.neighbors(agent) agent.initialize(x=x0, received_msgs=dict.fromkeys(neighbors, x0), aux_vars={"y": y0}) @@ -275,9 +304,8 @@ variable. Similarly, in order for the benchmark problem's communication schemes for k in range(self.iterations): for i in network.active_agents(k): i.aux_vars["y_new"] = i.x - self.step_size * i.cost.gradient(i.x) - neighborhood_avg = np.sum( - [W[i, j] * x_j for j, x_j in i.messages.items()], axis=0 - ) + s = iop.stack([W[i, j] * x_j for j, x_j in i.messages.items()]) + neighborhood_avg = iop.sum(s, axis=0) + neighborhood_avg += W[i, i] * i.x i.x = i.aux_vars["y_new"] - i.aux_vars["y"] + neighborhood_avg i.aux_vars["y"] = i.aux_vars["y_new"] @@ -304,6 +332,7 @@ Create your own metrics to tabulate and/or plot. .. 
code-block:: python import numpy.linalg as la + import decent_bench.utils.interoperability as iop from decent_bench import benchmark, benchmark_problem from decent_bench.agents import AgentMetricsView @@ -314,7 +343,8 @@ Create your own metrics to tabulate and/or plot. from decent_bench.metrics.table_metrics import DEFAULT_TABLE_METRICS, TableMetric def x_error_at_iter(agent: AgentMetricsView, problem: BenchmarkProblem, i: int = -1) -> float: - return float(la.norm(problem.optimal_x - agent.x_per_iteration[i])) + # Convert Array values to numpy for custom metric computation + return float(la.norm(iop.to_numpy(problem.optimal_x) - iop.to_numpy(agent.x_per_iteration[i]))) class XError(TableMetric): description: str = "x error" @@ -350,13 +380,75 @@ Create your own metrics to tabulate and/or plot. ) -Output ------- -Benchmark executions will have outputs like these: +Cost Functions +-------------- +Create new cost functions by subclassing :class:`~decent_bench.costs.Cost` and using interoperability decorators to keep +your implementation framework-agnostic. The decorators automatically wrap inputs/outputs as `Array` and ensure +compatibility with the selected framework and device of your custom cost. -.. list-table:: +.. code-block:: python - * - .. image:: _static/table.png - :align: center - - .. 
image:: _static/plot.png - :align: center + from numpy import float64 + from numpy.typing import NDArray + + import decent_bench.utils.interoperability as iop + from decent_bench.costs import Cost, SumCost + from decent_bench.utils.types import SupportedFrameworks, SupportedDevices + + class MyCost(Cost): + def __init__(self, A: Array, b: Array): + # Convert any external arrays to Array using the chosen framework/device + self.A: NDArray[float64] = iop.to_numpy(A) + self.b: NDArray[float64] = iop.to_numpy(b) + + @property + def shape(self) -> tuple[int, ...]: + # Domain shape (e.g., dimension of x) + return (self.A.shape[1],) + + @property + def framework(self) -> str: + return SupportedFrameworks.NUMPY + + @property + def device(self) -> str | None: + return SupportedDevices.CPU + + @property + def m_smooth(self) -> float: + # Provide a meaningful smoothness constant if available + return 1.0 + + @property + def m_cvx(self) -> float: + # Provide convexity constant (0 if non-strongly convex) + return 0.0 + + @iop.autodecorate_cost_method(Cost.function) + def function(self, x: NDArray[float64]) -> float: + # Return a scalar (float) or Array scalar compatible with the framework + r = self.A @ x - self.b + return 0.5 * float(iop.dot(r, r)) + + @iop.autodecorate_cost_method(Cost.gradient) + def gradient(self, x: NDArray[float64]) -> NDArray[float64]: + # Return an Array with same shape as x + return self.A.T @ (self.A @ x - self.b) + + @iop.autodecorate_cost_method(Cost.hessian) + def hessian(self, x: NDArray[float64]) -> NDArray[float64]: + # Optional: return an Array representing the Hessian + return self.A.T @ self.A + + @iop.autodecorate_cost_method(Cost.proximal) + def proximal(self, y: NDArray[float64], rho: float) -> NDArray[float64]: + # Optional: provide a closed-form proximal if available + # Otherwise you can rely on `centralized_algorithms.proximal_solver`. 
+ return y # identity as a placeholder + + def __add__(self, other: Cost) -> Cost: + # Support addition of costs + if self.shape != other.shape: + raise ValueError(f"Mismatching domain shapes: {self.shape} vs {other.shape}") + + return SumCost([self, other]) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3d752a1..3f1821c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,10 +38,17 @@ dev = [ "types-tabulate", "torch", "torchvision", - "tensorflow", "types-tensorflow", +] +dev-cpu = [ + "tensorflow", "jax", ] +dev-gpu = [ + "tensorflow[and-cuda]", + "jax[cuda12]", +] + [build-system] requires = ["hatchling"] @@ -54,6 +61,16 @@ envlist = ["dev", "mypy", "pytest", "ruff", "sphinx"] description = "Generate dev venv with all dependencies, active with `source .tox/dev/bin/activate`" deps = [ ".[dev]", + ".[dev-cpu]", + "git+https://github.com/microsoft/python-type-stubs.git@main", + "git+https://github.com/pydata/pydata-sphinx-theme.git@main" +] + +[tool.tox.env.dev-gpu] +description = "Generate dev venv with all dependencies including GPU support, active with `source .tox/dev-gpu/bin/activate`" +deps = [ + ".[dev]", + ".[dev-gpu]", "git+https://github.com/microsoft/python-type-stubs.git@main", "git+https://github.com/pydata/pydata-sphinx-theme.git@main" ] @@ -68,7 +85,7 @@ commands = [ [tool.tox.env.mypy] description = "Run mypy (static type checker)" -deps = [".[dev]", "git+https://github.com/microsoft/python-type-stubs.git@main"] +deps = [".[dev]", ".[dev-cpu]", "git+https://github.com/microsoft/python-type-stubs.git@main"] commands = [["mypy", "decent_bench"]] [tool.tox.env.pytest] diff --git a/test/utils/test_array.py b/test/utils/test_array.py new file mode 100644 index 0000000..2ee7557 --- /dev/null +++ b/test/utils/test_array.py @@ -0,0 +1,468 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_almost_equal as np_assert_almost_equal + +import decent_bench.utils.interoperability as iop +from 
decent_bench.utils.array import Array + +try: + import torch + from torch.testing import assert_close as torch_assert_close + + TORCH_AVAILABLE = True + TORCH_CUDA_AVAILABLE = torch.cuda.is_available() +except ModuleNotFoundError: + TORCH_AVAILABLE = False + TORCH_CUDA_AVAILABLE = False + +try: + import tensorflow as tf + + TF_AVAILABLE = True + TF_GPU_AVAILABLE = len(tf.config.list_physical_devices("GPU")) > 0 +except (ImportError, ModuleNotFoundError): + TF_AVAILABLE = False + TF_GPU_AVAILABLE = False + +try: + import jax + import jax.numpy as jnp + + JAX_AVAILABLE = True + JAX_GPU_AVAILABLE = len(jax.devices("gpu")) > 0 +except (ImportError, ModuleNotFoundError): + JAX_AVAILABLE = False + JAX_GPU_AVAILABLE = False +except RuntimeError: + # JAX raises RuntimeError if no GPU is available when querying devices + JAX_GPU_AVAILABLE = False + +# ============================================================================ +# Helpers +# ============================================================================ + + +def create_array(data: list, framework: str, device: str = "cpu"): + """ + Factory function to create arrays in different frameworks and devices. 
+ + Args: + data: Python list/nested list to convert + framework: One of 'numpy', 'torch', 'tensorflow', 'jax' + device: 'cpu' or 'gpu' + + Returns: + Array in the specified framework and device + + """ # noqa: D401, DOC501 + if data is None: + raise ValueError("Data cannot be None") + + if framework == "numpy": + return Array(np.array(data, dtype=np.float32)) + if framework == "torch": + array1 = torch.tensor(data, dtype=torch.float32) + if device == "gpu" and TORCH_CUDA_AVAILABLE: + array1 = array1.to("cuda") + return Array(array1) + if framework == "tensorflow": + device_str = "/GPU:0" if device == "gpu" and TF_GPU_AVAILABLE else "/CPU:0" + with tf.device(device_str): + array2: tf.Tensor = tf.constant(data, dtype=tf.float32) # type: ignore + return Array(array2) + elif framework == "jax": + array3 = jnp.array(data, dtype=jnp.float32) + if device == "gpu" and JAX_GPU_AVAILABLE: + gpu_devices = [d for d in jax.devices() if d.platform == "gpu"] + if gpu_devices: + array3 = jax.device_put(array3, device=gpu_devices[0]) + elif device == "cpu": + cpu_devices = [d for d in jax.devices("cpu") if d.platform == "cpu"] + if cpu_devices: + array3 = jax.device_put(array3, device=cpu_devices[0]) + return Array(array3) + else: + raise ValueError(f"Unknown framework: {framework}") + + +# ============================================================================ +# Tests for Array class +# ============================================================================ + + +def assert_arrays_equal(result, expected, framework: str): + """Framework-agnostic assertion for array equality.""" + result_np = iop.to_numpy(result) + expected_np = iop.to_numpy(expected) + + if framework == "torch" and isinstance(result, torch.Tensor): + # For torch, use torch_assert_close if result is still a tensor + expected_torch = torch.tensor(expected_np).to(result.dtype) + if result.is_cuda: + expected_torch = expected_torch.to("cuda") + torch_assert_close(result, expected_torch) + else: + 
np_assert_almost_equal(result_np, expected_np) + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +class TestArrayOperators: + """Test suite for Array class operators.""" + + def test_add(self, framework: str, device: str) -> None: + """Test addition operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + c = a + b + expected = a_np + b_np + assert_arrays_equal(c, expected, framework) + + d = c + a + expected = expected + a_np + assert_arrays_equal(d, expected, framework) + + c = a + 2 + expected = a_np + 2 + assert_arrays_equal(c, expected, framework) + + c = 2 + a + expected = 2 + a_np + assert_arrays_equal(c, expected, framework) + + def test_sub(self, framework: str, device: str) -> None: + """Test subtraction operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + c = a - b + expected = a_np - b_np + assert_arrays_equal(c, expected, framework) + + d = c 
- a + expected = expected - a_np + assert_arrays_equal(d, expected, framework) + + c = a - 1 + expected = a_np - 1 + assert_arrays_equal(c, expected, framework) + + c = 5 - a + expected = 5 - a_np + assert_arrays_equal(c, expected, framework) + + def test_mul(self, framework: str, device: str) -> None: + """Test multiplication operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + c = a * b + expected = a_np * b_np + assert_arrays_equal(c, expected, framework) + + d = c * a + expected = expected * a_np + assert_arrays_equal(d, expected, framework) + + c = a * 2 + expected = a_np * 2 + assert_arrays_equal(c, expected, framework) + + e = c * 2 + expected = expected * 2 + assert_arrays_equal(e, expected, framework) + + c = 2 * a + expected = 2 * a_np + assert_arrays_equal(c, expected, framework) + + def test_truediv(self, framework: str, device: str) -> None: + """Test true division operator.""" + a = create_array([[10, 20], [30, 40]], framework, device) + b = create_array([[2, 5], [10, 8]], framework, device) + a_np = create_array([[10, 20], [30, 40]], "numpy") + b_np = create_array([[2, 5], [10, 8]], "numpy") + c = a / b + expected = a_np / b_np + assert_arrays_equal(c, expected, framework) + + d = c / a + expected = expected / a_np + assert_arrays_equal(d, expected, framework) + + c = a / 10 + expected = a_np / 10 + assert_arrays_equal(c, expected, framework) + + c = 100 / a + expected = 100 / a_np + assert_arrays_equal(c, expected, framework) + + def test_matmul(self, framework: str, device: str) -> None: + """Test matrix multiplication operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + c = a @ b + expected = a_np @ b_np + 
assert_arrays_equal(c, expected, framework) + + d = c @ a + expected = expected @ a_np + assert_arrays_equal(d, expected, framework) + + def test_pow(self, framework: str, device: str) -> None: + """Test power operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + c = a**2 + expected = a_np**2 + assert_arrays_equal(c, expected, framework) + + d = c**2 + expected = expected**2 + assert_arrays_equal(d, expected, framework) + + def test_iadd(self, framework: str, device: str) -> None: + """Test in-place addition.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + a += b + expected = a_np + b_np + + assert_arrays_equal(a, expected, framework) + + a += 2 + expected = expected + 2 + assert_arrays_equal(a, expected, framework) + + a = create_array([[1, 2], [3, 4]], framework, device) + a += 2 + expected = a_np + 2 + assert_arrays_equal(a, expected, framework) + + def test_isub(self, framework: str, device: str) -> None: + """Test in-place subtraction.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + a -= b + expected = a_np - b_np + assert_arrays_equal(a, expected, framework) + + a -= 2 + expected = expected - 2 + assert_arrays_equal(a, expected, framework) + + a = create_array([[1, 2], [3, 4]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + a -= 1 + expected = a_np - 1 + assert_arrays_equal(a, expected, framework) + + def test_imul(self, framework: str, device: str) -> None: + """Test in-place multiplication.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 
4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + a *= b + expected = a_np * b_np + assert_arrays_equal(a, expected, framework) + + a *= 2 + expected = expected * 2 + assert_arrays_equal(a, expected, framework) + + a = create_array([[1, 2], [3, 4]], framework, device) + a *= 2 + expected = a_np * 2 + assert_arrays_equal(a, expected, framework) + + def test_itruediv(self, framework: str, device: str) -> None: + """Test in-place true division.""" + a = create_array([[10, 20], [30, 40]], framework, device) + b = create_array([[2, 5], [10, 8]], framework, device) + a_np = create_array([[10, 20], [30, 40]], "numpy") + b_np = create_array([[2, 5], [10, 8]], "numpy") + a /= b + expected = a_np / b_np + assert_arrays_equal(a, expected, framework) + + a /= b + expected = expected / b_np + assert_arrays_equal(a, expected, framework) + + a = create_array([[10, 20], [30, 40]], framework, device) + a /= 10 + expected = a_np / 10 + assert_arrays_equal(a, expected, framework) + + def test_ipow(self, framework: str, device: str) -> None: + """Test in-place power operator.""" + a = create_array([[1, 2], [3, 4]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + a **= 2 + expected = a_np**2 + assert_arrays_equal(a, expected, framework) + + a **= 2 + expected = expected**2 + assert_arrays_equal(a, expected, framework) + + def test_neg(self, framework: str, device: str) -> None: + """Test negation operator.""" + a = create_array([[1, -2], [-3, 4]], framework, device) + a_np = create_array([[1, -2], [-3, 4]], "numpy") + c = -a + expected = -a_np + assert_arrays_equal(c, expected, framework) + + d = -c + expected = -expected + assert_arrays_equal(d, expected, framework) + + def test_abs(self, framework: str, device: str) -> None: + """Test absolute value.""" + a = create_array([[1, -2], [-3, 4]], framework, device) + a_np = create_array([[1, -2], [-3, 4]], "numpy") + c = abs(a) + expected = abs(a_np) + assert_arrays_equal(c, expected, framework) + + d 
= abs(c) + expected = abs(expected) + assert_arrays_equal(d, expected, framework) + + def test_getitem(self, framework: str, device: str) -> None: + """Test __getitem__ method.""" + a = create_array([[1, 2, 3], [4, 5, 6]], framework, device) + a_np = create_array([[1, 2, 3], [4, 5, 6]], "numpy") + item = a[0, 1] + expected = a_np[0, 1] + assert_arrays_equal(item, expected, framework) + + slice_ = a[0, :] + expected = a_np[0, :] + assert_arrays_equal(slice_, expected, framework) + + def test_setitem(self, framework: str, device: str) -> None: + """Test __setitem__ method.""" + + if framework in ["jax", "tensorflow"]: + pytest.skip("Setitem not supported for JAX and TensorFlow due to immutability.") + + a = create_array([[1, 2], [3, 4]], framework, device) + a[0, 0] = 99.0 + expected = create_array([[1, 2], [3, 4]], "numpy") + expected[0, 0] = 99.0 + assert_arrays_equal(a, expected, framework) + + b = create_array([10, 20], framework, device) + b_np = create_array([10, 20], "numpy") + a[1, :] = b + expected[1, :] = b_np + assert_arrays_equal(a, expected, framework) + + def test_len(self, framework: str, device: str) -> None: + """Test __len__ method.""" + a = create_array([[1, 2, 3], [4, 5, 6]], framework, device) + assert len(a) == 2 + + a_scalar = create_array(5.0, framework, device) + with pytest.raises(TypeError): + len(a_scalar) + + def test_iter(self, framework: str, device: str) -> None: + """Test __iter__ method.""" + a = create_array([[1, 2], [3, 4]], framework, device) + it = iter(a) + row1 = next(it) + assert_arrays_equal(Array(row1), np.array([1, 2]), framework) + row2 = next(it) + assert_arrays_equal(Array(row2), np.array([3, 4]), framework) + with pytest.raises(StopIteration): + next(it) + + a_scalar = create_array(5, framework, device) + with pytest.raises(TypeError): + iter(a_scalar) + + def test_float(self, framework: str, device: str) -> None: + """Test __float__ method.""" + a = create_array(42.0, framework, device) + f = float(a) + assert 
isinstance(f, float) + assert f == 42.0 + + f = float(a) + assert isinstance(f, float) + assert f == 42.0 + + a_array = create_array([42.0], framework, device) + f = float(a_array) + assert isinstance(f, float) + assert f == 42.0 + + a_array_2d = create_array([[42.0]], framework, device) + f = float(a_array_2d) + assert isinstance(f, float) + assert f == 42.0 + + a_non_scalar = create_array([1.0, 2.0], framework, device) + with pytest.raises((TypeError, ValueError, RuntimeError)): + float(a_non_scalar) + + def test_combinations(self, framework: str, device: str) -> None: + """Test combinations of operations.""" + a = create_array([[1, 2], [3, 4]], framework, device) + b = create_array([[5, 6], [7, 8]], framework, device) + a_np = create_array([[1, 2], [3, 4]], "numpy") + b_np = create_array([[5, 6], [7, 8]], "numpy") + + c = (a + b) * 2 - 3 / (a - 0.5) + expected = (a_np + b_np) * 2 - 3 / (a_np - 0.5) + assert_arrays_equal(c, expected, framework) diff --git a/test/utils/test_interoperability.py b/test/utils/test_interoperability.py index 5004238..03db3f9 100644 --- a/test/utils/test_interoperability.py +++ b/test/utils/test_interoperability.py @@ -1,9 +1,13 @@ +from typing import Any + import numpy as np import pytest +from numpy.testing import assert_array_almost_equal as np_assert_almost_equal from numpy.testing import assert_array_equal as np_assert_equal -from typing import Any -import decent_bench.utils.interoperability as Interoperability +import decent_bench.utils.interoperability as iop +from decent_bench.utils.array import Array +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks try: import torch @@ -43,56 +47,61 @@ def create_array(data: list, framework: str, device: str = "cpu"): - """Factory function to create arrays in different frameworks and devices. + """ + Factory function to create arrays in different frameworks and devices. 
Args: data: Python list/nested list to convert - framework: One of 'numpy', 'torch', 'tensorflow', 'jax', 'list', 'tuple' + framework: One of 'numpy', 'torch', 'tensorflow', 'jax' device: 'cpu' or 'gpu' - """ + + Returns: + Array in the specified framework and device + + """ # noqa: D401, DOC501 if data is None: raise ValueError("Data cannot be None") if framework == "numpy": - return np.array(data, dtype=np.float32) - elif framework == "torch": + return Array(np.array(data, dtype=np.float32)) + if framework == "torch": array1 = torch.tensor(data, dtype=torch.float32) if device == "gpu" and TORCH_CUDA_AVAILABLE: array1 = array1.to("cuda") - return array1 - elif framework == "tensorflow": + return Array(array1) + if framework == "tensorflow": device_str = "/GPU:0" if device == "gpu" and TF_GPU_AVAILABLE else "/CPU:0" with tf.device(device_str): array2: tf.Tensor = tf.constant(data, dtype=tf.float32) # type: ignore - return array2 + return Array(array2) elif framework == "jax": array3 = jnp.array(data, dtype=jnp.float32) if device == "gpu" and JAX_GPU_AVAILABLE: - gpu_devices = [d for d in jax.devices() if d.platform == "gpu"] + gpu_devices = [d for d in jax.devices("gpu") if d.platform == "gpu"] if gpu_devices: array3 = jax.device_put(array3, device=gpu_devices[0]) - return array3 - elif framework == "list": - return data - elif framework == "tuple": - return tuple(data) + elif device == "cpu": + cpu_devices = [d for d in jax.devices("cpu") if d.platform == "cpu"] + if cpu_devices: + array3 = jax.device_put(array3, device=cpu_devices[0]) + return Array(array3) else: raise ValueError(f"Unknown framework: {framework}") def assert_arrays_equal(result, expected, framework: str): """Framework-agnostic assertion for array equality.""" - result_np = Interoperability.to_numpy(result) - expected_np = Interoperability.to_numpy(expected) + result_np = iop.to_numpy(result) + expected_np = iop.to_numpy(expected) if framework == "torch" and isinstance(result, torch.Tensor): # For 
torch, use torch_assert_close if result is still a tensor - expected_torch = torch.from_numpy(expected_np) + expected_torch = torch.tensor(expected_np).to(result.dtype) if result.is_cuda: expected_torch = expected_torch.to("cuda") torch_assert_close(result, expected_torch) else: - np_assert_equal(result_np, expected_np) + np_assert_almost_equal(result_np, expected_np) def assert_shapes_equal(result, expected, framework): @@ -104,28 +113,17 @@ def assert_shapes_equal(result, expected, framework): def assert_same_type(result: Any, framework: str): """Assert that the result is of the expected type based on the framework.""" + if isinstance(result, Array): + result = result.value + if framework == "numpy": - assert "numpy" in str( - type(result) - ), f"Expected numpy.ndarray, got {type(result)}" + assert "numpy" in str(type(result)), f"Expected numpy.ndarray, got {type(result)}" elif framework == "torch": - assert "torch" in str( - type(result) - ), f"Expected torch.Tensor, got {type(result)}" + assert "torch" in str(type(result)), f"Expected torch.Tensor, got {type(result)}" elif framework == "tensorflow": - assert "tensorflow" in str( - type(result) - ), f"Expected tf.Tensor, got {type(result)}" + assert "tensorflow" in str(type(result)), f"Expected tf.Tensor, got {type(result)}" elif framework == "jax": assert "jax" in str(type(result)), f"Expected jnp.ndarray, got {type(result)}" - elif framework == "list": - assert isinstance(result, list) or isinstance( - result, (int, float, complex) - ), f"Expected list, got {type(result)}" - elif framework == "tuple": - assert isinstance(result, tuple) or isinstance( - result, (int, float, complex) - ), f"Expected tuple, got {type(result)}" else: raise ValueError(f"Unknown framework: {framework}") @@ -137,7 +135,7 @@ def assert_same_type(result: Any, framework: str): def test_numpy_passthrough(): arr = np.array([1, 2, 3], dtype=np.int32) - out = Interoperability.to_numpy(arr) + out = iop.to_numpy(arr) # Should return the same 
numpy array object assert out is arr np_assert_equal(out, np.array([1, 2, 3], dtype=np.int32)) @@ -146,29 +144,29 @@ def test_numpy_passthrough(): def test_scalars_and_none(): # None becomes a 0-d object array containing None # None should not be an input to to_numpy but we test it anyway - out = Interoperability.to_numpy(None) # type: ignore + out = iop.to_numpy(None) # type: ignore assert isinstance(out, np.ndarray) assert out.shape == () assert out.tolist() is None # Scalars become 0-d numpy arrays - out = Interoperability.to_numpy(5) + out = iop.to_numpy(5) assert isinstance(out, np.ndarray) assert out.shape == () assert out.item() == 5 - out = Interoperability.to_numpy(3.14) + out = iop.to_numpy(3.14) assert isinstance(out, np.ndarray) assert out.shape == () assert out.item() == pytest.approx(3.14) def test_list_of_scalars_conversion(): - out = Interoperability.to_numpy([1, 2, 3]) + out = iop.to_numpy([1, 2, 3]) assert isinstance(out, np.ndarray) np_assert_equal(out, np.array([1, 2, 3])) - out = Interoperability.to_numpy([1.5, 2.43, 3.0]) + out = iop.to_numpy([1.5, 2.43, 3.0]) assert isinstance(out, np.ndarray) np_assert_equal(out, np.array([1.5, 2.43, 3.0])) @@ -180,7 +178,7 @@ def test_dictionary_conversion(): "a": [np.array([1, 2]), 3], "b": (np.array([4]), {"c": np.array([5])}), } - out = Interoperability.to_numpy(nested) # type: ignore + out = iop.to_numpy(nested) # type: ignore assert isinstance(out, np.ndarray) assert out.shape == () assert out.dtype == object @@ -193,30 +191,22 @@ def test_dictionary_conversion(): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( 
"tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -226,9 +216,7 @@ def test_dictionary_conversion(): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), ], ) @@ -236,7 +224,7 @@ def test_to_numpy_frameworks(framework: str, device: str): """Test to_numpy conversion for all frameworks and devices.""" data = [1, 2, 3] arr = create_array(data, framework, device) - out = Interoperability.to_numpy(arr) + out = iop.to_numpy(arr) assert isinstance(out, np.ndarray) np_assert_equal(out, np.array(data, dtype=np.float32)) @@ -248,30 +236,22 @@ def test_to_numpy_frameworks(framework: str, device: str): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( 
"jax", @@ -281,24 +261,182 @@ def test_to_numpy_frameworks(framework: str, device: str): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) -def test_from_numpy_frameworks(framework, device: str): +def test_numpy_to_frameworks_like(framework, device: str): """Test from_numpy conversion for all frameworks and devices.""" like = create_array([1, 2], framework, device) data = [1, 2, 3] np_arr = np.array(data, dtype=np.float32) - out = Interoperability.from_numpy_like(np_arr, like) + out = iop.to_array_like(np_arr, like) + + assert isinstance(out, type(like.value)), f"Expected type {type(like.value)}, got {type(out)}" + + +# ============================================================================ +# Tests for Interoperability.to_torch +# ============================================================================ + + +@pytest.mark.parametrize( + "framework,device", + [ + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_to_torch_frameworks(framework: str, device: str): + """Test to_torch conversion for all frameworks and devices.""" + data = [1, 2, 3] + arr = 
create_array(data, framework, device) + out = iop.to_torch(arr, SupportedDevices(device)) + + assert isinstance(out, torch.Tensor), f"Expected torch.Tensor, got {type(out)}" + assert out.device.type == ("cuda" if device == "gpu" and TORCH_CUDA_AVAILABLE else "cpu"), ( + f"Expected device {device}, got {out.device.type}" + ) + equals = create_array(data, "torch", device) + assert_arrays_equal(out, equals, "torch") + + +# ============================================================================ +# Tests for Interoperability.to_tensorflow +# ============================================================================ + + +@pytest.mark.parametrize( + "framework,device", + [ + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_to_tensorflow_frameworks(framework: str, device: str): + """Test to_tensorflow conversion for all frameworks and devices.""" + data = [1, 2, 3] + arr = create_array(data, framework, device) + out = iop.to_tensorflow(arr, SupportedDevices(device)) - assert isinstance(out, type(like)) + assert isinstance(out, tf.Tensor), f"Expected tf.Tensor, got {type(out)}" + assert ("gpu" if device == "gpu" and TF_GPU_AVAILABLE else "cpu") in out.device.lower(), ( + f"Expected device {device}, got {out.device}" + ) + equals = create_array(data, 
"tensorflow", device) + assert_arrays_equal(out, equals, "tensorflow") + + +# ============================================================================ +# Tests for Interoperability.to_jax +# ============================================================================ + + +@pytest.mark.parametrize( + "framework,device", + [ + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_to_jax_frameworks(framework: str, device: str): + """Test to_jax conversion for all frameworks and devices.""" + data = [1, 2, 3] + arr = create_array(data, framework, device) + out = iop.to_jax(arr, SupportedDevices(device)) + + assert isinstance(out, jax.Array), f"Expected jax.Array, got {type(out)}" + assert out.device.platform == ("gpu" if device == "gpu" and JAX_GPU_AVAILABLE else "cpu"), ( + f"Expected device {device}, got {out.device.platform}" + ) + equals = create_array(data, "jax", device) + assert_arrays_equal(out, equals, "jax") # ============================================================================ @@ -313,30 +451,22 @@ def test_from_numpy_frameworks(framework, device: str): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not 
available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -346,12 +476,8 @@ def test_from_numpy_frameworks(framework, device: str): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -376,7 +502,7 @@ def test_sum_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.sum(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.sum(arr, dim=dim, keepdims=keepdims) + result = iop.sum(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -388,30 +514,22 @@ def test_sum_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - 
marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -421,12 +539,8 @@ def test_sum_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -451,7 +565,7 @@ def test_mean_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.mean(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.mean(arr, dim=dim, keepdims=keepdims) + result = iop.mean(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -463,30 +577,22 @@ def test_mean_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, 
reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -496,12 +602,8 @@ def test_mean_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -526,7 +628,7 @@ def test_min_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.min(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.min(arr, dim=dim, keepdims=keepdims) + result = iop.min(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -538,30 +640,22 @@ def test_min_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -571,12 +665,8 @@ def test_min_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", 
"gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -601,7 +691,7 @@ def test_max_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.max(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.max(arr, dim=dim, keepdims=keepdims) + result = iop.max(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -618,30 +708,22 @@ def test_max_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -651,12 +733,8 @@ def test_max_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -679,7 +757,7 @@ 
def test_argmax_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.argmax(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.argmax(arr, dim=dim, keepdims=keepdims) + result = iop.argmax(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -691,30 +769,22 @@ def test_argmax_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -724,12 +794,8 @@ def test_argmax_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -752,7 +818,7 @@ def test_argmin_all_combinations(framework: str, device: str, dim, keepdims): np_arr = create_array(data, "numpy") expected = np.argmin(np_arr, axis=dim, keepdims=keepdims) - result = Interoperability.argmin(arr, dim=dim, keepdims=keepdims) + result = 
iop.argmin(arr, dim=dim, keepdims=keepdims) assert_arrays_equal(result, expected, framework) assert_same_type(result, framework) @@ -769,30 +835,22 @@ def test_argmin_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -802,19 +860,15 @@ def test_argmin_all_combinations(framework: str, device: str, dim, keepdims): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) def test_copy_frameworks(framework: str, device: str): """Test copy function for all frameworks and devices.""" data = [[1, 2, 3], [4, 5, 6]] arr = create_array(data, framework, device) - arr_copy = Interoperability.copy(arr) + arr_copy = iop.copy(arr) # Ensure the copied array is equal to the original assert_arrays_equal(arr_copy, arr, framework) @@ -854,30 +908,22 @@ def test_copy_frameworks(framework: str, device: str): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + 
marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -887,12 +933,8 @@ def test_copy_frameworks(framework: str, device: str): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -905,12 +947,12 @@ def test_stack_frameworks(framework: str, device: str, dim: int): data2 = [[7, 8, 9], [10, 11, 12]] arr1 = create_array(data1, framework, device) arr2 = create_array(data2, framework, device) - stacked = Interoperability.stack([arr1, arr2], dim=dim) + stacked = iop.stack([arr1, arr2], dim=dim) # Compute expected result using numpy np_arr1 = create_array(data1, "numpy") np_arr2 = create_array(data2, "numpy") - expected = np.stack([np_arr1, np_arr2], axis=dim) + expected = np.stack([np_arr1.value, np_arr2.value], axis=dim) assert_arrays_equal(stacked, expected, framework) assert_same_type(stacked, framework) @@ -928,30 +970,22 @@ def test_stack_frameworks(framework: str, device: str, dim: int): pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, 
reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -961,25 +995,19 @@ def test_stack_frameworks(framework: str, device: str, dim: int): pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( "new_shape", [(3, 2), (2, 3), (6,), (-1,), (2, 1, 3), (1, 6)], ) -def test_reshape_matrix_frameworks( - framework: str, device: str, new_shape: tuple[int, ...] 
-): +def test_reshape_matrix_frameworks(framework: str, device: str, new_shape: tuple[int, ...]): """Test reshape function for all frameworks and devices.""" data = [[1, 2, 3], [4, 5, 6]] arr = create_array(data, framework, device) - reshaped = Interoperability.reshape(arr, new_shape) + reshaped = iop.reshape(arr, new_shape) # Compute expected result using numpy np_arr = create_array(data, "numpy") @@ -997,30 +1025,22 @@ def test_reshape_matrix_frameworks( pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1030,25 +1050,19 @@ def test_reshape_matrix_frameworks( pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( "new_shape", [(3, 2), (2, 3), (6,), (-1,), (2, 1, 3), (1, 6)], ) -def test_reshape_vector_frameworks( - framework: str, device: str, new_shape: tuple[int, ...] 
-): +def test_reshape_vector_frameworks(framework: str, device: str, new_shape: tuple[int, ...]): """Test reshape function for all frameworks and devices.""" data = [1, 2, 3, 4, 5, 6] arr = create_array(data, framework, device) - reshaped = Interoperability.reshape(arr, new_shape) + reshaped = iop.reshape(arr, new_shape) # Compute expected result using numpy np_arr = create_array(data, "numpy") @@ -1071,30 +1085,22 @@ def test_reshape_vector_frameworks( pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1104,12 +1110,8 @@ def test_reshape_vector_frameworks( pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -1120,8 +1122,8 @@ def test_zeros_like_frameworks(framework: str, device: str, shape: tuple[int, .. 
"""Test zeros_like function for all frameworks and devices.""" data = [1, 2, 3, 4, 5, 6] arr = create_array(data, framework, device) - arr = Interoperability.reshape(arr, shape) - zeros = Interoperability.zeros_like(arr) + arr = iop.reshape(arr, shape) + zeros = iop.zeros_like(arr) # Compute expected result using numpy np_arr = create_array(data, "numpy") @@ -1140,30 +1142,22 @@ def test_zeros_like_frameworks(framework: str, device: str, shape: tuple[int, .. pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1173,12 +1167,8 @@ def test_zeros_like_frameworks(framework: str, device: str, shape: tuple[int, .. pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -1189,8 +1179,8 @@ def test_ones_like_frameworks(framework: str, device: str, shape: tuple[int, ... 
"""Test ones_like function for all frameworks and devices.""" data = [1, 2, 3, 4, 5, 6] arr = create_array(data, framework, device) - arr = Interoperability.reshape(arr, shape) - ones = Interoperability.ones_like(arr) + arr = iop.reshape(arr, shape) + ones = iop.ones_like(arr) # Compute expected result using numpy np_arr = create_array(data, "numpy") @@ -1209,30 +1199,22 @@ def test_ones_like_frameworks(framework: str, device: str, shape: tuple[int, ... pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1242,12 +1224,8 @@ def test_ones_like_frameworks(framework: str, device: str, shape: tuple[int, ... pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -1258,8 +1236,8 @@ def test_rand_like_frameworks(framework: str, device: str, shape: tuple[int, ... 
"""Test rand_like function for all frameworks and devices.""" data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] arr = create_array(data, framework, device) - arr = Interoperability.reshape(arr, shape) - rand_arr = Interoperability.rand_like(arr) + arr = iop.reshape(arr, shape) + rand_arr = iop.rand_like(arr) # Compute expected shape using numpy np_arr = create_array(data, "numpy") @@ -1276,30 +1254,22 @@ def test_rand_like_frameworks(framework: str, device: str, shape: tuple[int, ... pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1309,12 +1279,8 @@ def test_rand_like_frameworks(framework: str, device: str, shape: tuple[int, ... pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( @@ -1325,8 +1291,8 @@ def test_randn_like_frameworks(framework: str, device: str, shape: tuple[int, .. 
"""Test randn_like function for all frameworks and devices.""" data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] arr = create_array(data, framework, device) - arr = Interoperability.reshape(arr, shape) - rand_arr = Interoperability.randn_like(arr) + arr = iop.reshape(arr, shape) + rand_arr = iop.randn_like(arr) # Compute expected shape using numpy np_arr = create_array(data, "numpy") @@ -1337,36 +1303,28 @@ def test_randn_like_frameworks(framework: str, device: str, shape: tuple[int, .. @pytest.mark.parametrize( - "framework,device", + ("framework", "device"), [ ("numpy", "cpu"), pytest.param( "torch", "cpu", - marks=pytest.mark.skipif( - not TORCH_AVAILABLE, reason="PyTorch not available" - ), + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), ), pytest.param( "torch", "gpu", - marks=pytest.mark.skipif( - not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available" - ), + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), ), pytest.param( "tensorflow", "cpu", - marks=pytest.mark.skipif( - not TF_AVAILABLE, reason="TensorFlow not available" - ), + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), ), pytest.param( "tensorflow", "gpu", - marks=pytest.mark.skipif( - not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available" - ), + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), ), pytest.param( "jax", @@ -1376,24 +1334,71 @@ def test_randn_like_frameworks(framework: str, device: str, shape: tuple[int, .. 
pytest.param( "jax", "gpu", - marks=pytest.mark.skipif( - not JAX_GPU_AVAILABLE, reason="JAX GPU not available" - ), + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize( + "n", + [1, 2, 4], +) +def test_eye_frameworks(framework: str, device: str, n: int) -> None: + """Test eye function for all frameworks and devices.""" + eye_arr = iop.eye(n, SupportedFrameworks(framework), SupportedDevices(device)) + + # Compute expected result using numpy + expected = np.eye(n, dtype=np.float64) + + assert_shapes_equal(eye_arr, expected, framework) + assert_arrays_equal(eye_arr, expected, framework) + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), ), - ("list", "cpu"), - ("tuple", "cpu"), ], ) @pytest.mark.parametrize( "shape", - [(4, 4), (2, 8), (1, 16), (-1,), (1, 2, 8)], + [(4, 4), (2, 8), (1, 16), (1, 2, 8)], ) -def test_eye_like_frameworks(framework: str, device: str, shape: tuple[int, ...]): +def test_eye_like_frameworks(framework: str, device: str, shape: tuple[int, ...]) -> None: """Test eye_like function for all frameworks and devices.""" data = list(range(16)) arr = create_array(data, framework, device) - arr = 
Interoperability.reshape(arr, shape) - eye_arr = Interoperability.eye_like(arr) + arr = iop.reshape(arr, shape) + eye_arr = iop.eye_like(arr) # Compute expected result using numpy np_arr = create_array(data, "numpy") @@ -1401,9 +1406,370 @@ def test_eye_like_frameworks(framework: str, device: str, shape: tuple[int, ...] expected = ( np.eye(*np_arr.shape[-2:], dtype=np_arr.dtype) if len(np_arr.shape) >= 2 - else np.eye(np_arr.shape[0], dtype=np_arr.dtype) # type: ignore + else np.eye(np_arr.shape[0], dtype=np_arr.dtype) ) assert_shapes_equal(eye_arr, expected, framework) assert_arrays_equal(eye_arr, expected, framework) assert_same_type(eye_arr, framework) + + +# ============================================================================ +# Tests for transpose +# ============================================================================ + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize( + "dims", + [None, (1, 0, 2), (2, 1, 0)], +) +def test_transpose_frameworks(framework: str, device: str, dims: tuple[int, ...] 
| None) -> None: + """Test transpose function for all frameworks and devices.""" + data = np.arange(24).reshape((2, 3, 4)) + arr = create_array(data.tolist(), framework, device) + transposed_arr = iop.transpose(arr, dim=dims) + + # Compute expected result using numpy + np_arr = create_array(data.tolist(), "numpy") + expected = np.transpose(np_arr, axes=dims) + + assert_shapes_equal(transposed_arr, expected, framework) + assert_arrays_equal(transposed_arr, expected, framework) + assert_same_type(transposed_arr, framework) + + +# ============================================================================ +# Tests for shape +# ============================================================================ + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_shape_frameworks(framework: str, device: str) -> None: + """Test shape function for all frameworks and devices.""" + data = [[1, 2, 3], [4, 5, 6]] + arr = create_array(data, framework, device) + shape = iop.shape(arr) + assert shape == (2, 3) + + +# ============================================================================ +# Tests for zeros +# ============================================================================ + + 
+@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize( + "shape", + [(2, 3), (5,)], +) +def test_zeros_frameworks(framework: str, device: str, shape: tuple[int, ...]) -> None: + """Test zeros function for all frameworks and devices.""" + zeros_arr = iop.zeros(framework=SupportedFrameworks(framework), shape=shape, device=SupportedDevices(device)) + expected = np.zeros(shape) + + assert_shapes_equal(zeros_arr, expected, framework) + assert_arrays_equal(zeros_arr, expected, framework) + assert_same_type(zeros_arr, framework) + + +# ============================================================================ +# Tests for get_item and set_item +# ============================================================================ + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + ], +) +def test_get_set_item_frameworks(framework: str, device: str) -> None: + """Test get_item and set_item functions for 
all frameworks and devices.""" + data = [[1, 2, 3], [4, 5, 6]] + arr = create_array(data, framework, device) + + # Test get_item + item = iop.get_item(arr, (0, 1)) + assert iop.to_numpy(item) == 2 + + # Test set_item + val = create_array(99, framework, device) + iop.set_item(arr, (0, 1), val) + item = iop.get_item(arr, (0, 1)) + assert iop.to_numpy(item) == 99 + + +# ============================================================================ +# Tests for astype +# ============================================================================ + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize( + ("to_type", "expected_val"), + [(int, 5), (float, 5.0), (bool, True)], +) +def test_astype_frameworks(framework: str, device: str, to_type: type, expected_val: Any) -> None: + """Test astype function for all frameworks and devices.""" + arr = create_array([5.0], framework, device) + val = iop.astype(arr, to_type) + assert val == expected_val + assert isinstance(val, to_type) + + +# ============================================================================ +# Tests for norm +# ============================================================================ + + 
+@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +@pytest.mark.parametrize( + ("p_norm", "data"), + [(2, [3.0, 4.0]), (1, [3.0, 4.0]), (1, [[1.0, -2.0], [-3.0, 4.0]]), (2, [[1.0, -2.0], [-3.0, 4.0]])], +) +def test_norm_frameworks(framework: str, device: str, p_norm: int, data: list) -> None: + """Test norm function for all frameworks and devices.""" + arr = create_array(data, framework, device) + norm_val = iop.norm(arr, p=p_norm) + + np_arr = create_array(data, "numpy") + expected = np.linalg.norm(np_arr, ord=p_norm) + + assert_arrays_equal(norm_val, expected, framework) + assert_same_type(norm_val, framework) + + +# ============================================================================ +# Tests for framework_device_of_array +# ============================================================================ + + +@pytest.mark.parametrize( + ("framework", "device"), + [ + ("numpy", "cpu"), + pytest.param( + "torch", + "cpu", + marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available"), + ), + pytest.param( + "torch", + "gpu", + marks=pytest.mark.skipif(not TORCH_CUDA_AVAILABLE, reason="PyTorch CUDA not available"), + ), + pytest.param( + "tensorflow", + "cpu", + 
marks=pytest.mark.skipif(not TF_AVAILABLE, reason="TensorFlow not available"), + ), + pytest.param( + "tensorflow", + "gpu", + marks=pytest.mark.skipif(not TF_GPU_AVAILABLE, reason="TensorFlow GPU not available"), + ), + pytest.param( + "jax", + "cpu", + marks=pytest.mark.skipif(not JAX_AVAILABLE, reason="JAX not available"), + ), + pytest.param( + "jax", + "gpu", + marks=pytest.mark.skipif(not JAX_GPU_AVAILABLE, reason="JAX GPU not available"), + ), + ], +) +def test_framework_device_of_array(framework: str, device: str) -> None: + """Test framework_device_of_array function for all frameworks and devices.""" + arr = create_array([1.0, 2.0, 3.0], framework, device) + fw, dev = iop.framework_device_of_array(arr) + + assert fw == SupportedFrameworks(framework), f"Expected framework {framework}, got {fw}" + assert dev == SupportedDevices(device), f"Expected device {device}, got {dev}" From 96a89e62fb7efec6ede24502501f3436f905ad3e Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 12 Dec 2025 13:42:16 +0100 Subject: [PATCH 03/16] fix(networks): resolve lint and doc warning Organize imports and remove extraneous whitespace in networks.py to satisfy ruff. Update user guide to reference Network.active_agents so Sphinx can find the target. 
--- decent_bench/networks.py | 4 +--- docs/source/user.rst | 6 +++--- test/utils/test_array.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index a81095d..1d22fe4 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -6,7 +6,6 @@ import networkx as nx import numpy as np -from networkx import Graph import decent_bench.utils.interoperability as iop from decent_bench.agents import Agent @@ -187,12 +186,11 @@ def adjacency(self) -> Array: A[i, j] = 1 return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device) - + def neighbors(self, agent: Agent) -> list[Agent]: """Get all neighbors of an agent.""" return list(self.graph[agent]) - def broadcast(self, sender: Agent, msg: Array) -> None: """ Send message to all neighbors. diff --git a/docs/source/user.rst b/docs/source/user.rst index 0ee78a8..19fac96 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -274,8 +274,8 @@ Create a new algorithm to benchmark against existing ones. **Note**: In order for metrics to work, use :attr:`Agent.x ` to update the local primal variable. Similarly, in order for the benchmark problem's communication schemes to be applied, use the -:attr:`~decent_bench.networks.P2PNetwork` object to retrieve agents and to send and receive messages. -Be sure to use :meth:`~decent_bench.networks.P2PNetwork.active_agents` to during algorithm runtime, so that asynchrony is properly handled. +:attr:`~decent_bench.networks.P2PNetwork`/ :attr:`~decent_bench.networks.FedNetwork` object to retrieve agents and to send and receive messages. +Be sure to use :meth:`~decent_bench.networks.Network.active_agents` during algorithm runtime, so that asynchrony is properly handled. .. code-block:: python @@ -451,4 +451,4 @@ compatibility with the selected framework and device of your custom cost. 
if self.shape != other.shape: raise ValueError(f"Mismatching domain shapes: {self.shape} vs {other.shape}") - return SumCost([self, other]) \ No newline at end of file + return SumCost([self, other]) diff --git a/test/utils/test_array.py b/test/utils/test_array.py index 2ee7557..b14b762 100644 --- a/test/utils/test_array.py +++ b/test/utils/test_array.py @@ -68,7 +68,7 @@ def create_array(data: list, framework: str, device: str = "cpu"): if framework == "tensorflow": device_str = "/GPU:0" if device == "gpu" and TF_GPU_AVAILABLE else "/CPU:0" with tf.device(device_str): - array2: tf.Tensor = tf.constant(data, dtype=tf.float32) # type: ignore + array2: tf.Tensor = tf.constant(data, dtype=tf.float64) # type: ignore return Array(array2) elif framework == "jax": array3 = jnp.array(data, dtype=jnp.float32) From e6fce231d57cbaf519b1a5db2656affd61859059 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Sat, 13 Dec 2025 14:51:57 +0100 Subject: [PATCH 04/16] enh(networks): address review feedback (#229) Adjust kind to return the concrete class and simplify FL send_all validation. --- decent_bench/networks.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 1d22fe4..6388c84 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -55,9 +55,9 @@ def message_drop(self) -> DropScheme: return self._message_drop @abstractmethod - def kind(self) -> str: - """Label for the network subtype (e.g., 'p2p', 'fed').""" - raise NotImplementedError + def kind(self) -> type["Network"]: + """Concrete network class for type-based dispatch.""" + ... 
def agents(self) -> list[Agent]: """Get all agents in the network.""" @@ -124,9 +124,9 @@ def __init__( ) self.W: Array | None = None - def kind(self) -> str: - """Label for the network subtype.""" - return "p2p" + def kind(self) -> type["Network"]: + """Concrete network class for type-based dispatch.""" + return type(self) def set_weights(self, weights: Array) -> None: """ @@ -245,9 +245,9 @@ def _identify_server(self) -> Agent: raise ValueError("FedNetwork expects a star topology with one server connected to all clients") return server - def kind(self) -> str: - """Label for the network subtype.""" - return "fed" + def kind(self) -> type["Network"]: + """Concrete network class for type-based dispatch.""" + return type(self) @property def server(self) -> Agent: @@ -308,10 +308,11 @@ def send_from_all_clients(self, msgs: Mapping[Agent, Array]) -> None: """ clients = set(self.clients) - invalid = [client for client in msgs if client not in clients] + senders = set(msgs) + invalid = senders - clients if invalid: raise ValueError("All senders must be clients") - missing = clients - set(msgs) + missing = clients - senders if missing: raise ValueError("Messages must be provided for all clients") for client, msg in msgs.items(): From 6973a75dd9ea5ee6211a76904f650664a81d3f29 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Thu, 18 Dec 2025 20:19:34 +0100 Subject: [PATCH 05/16] ref(networks): align APIs with PR feedback (#229) Unify send/receive with optional None/iterables and keep aliases in P2P; enforce role-aware Fed send/receive and use client-only agents; remove unused accessors/kind, add doc notes on graph mutability and Fed agent semantics; update docs accordingly. 
closes #192 --- decent_bench/networks.py | 263 ++++++++++++++++++++++++++++----------- docs/source/user.rst | 2 + 2 files changed, 190 insertions(+), 75 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 6388c84..62c4db0 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,7 +1,6 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping +from abc import ABC +from collections.abc import Iterable, Mapping from functools import cached_property -from operator import itemgetter from typing import TYPE_CHECKING import networkx as nx @@ -19,8 +18,8 @@ AgentGraph = nx.Graph -class Network(ABC): - """Base network that defines the communication constraints shared by all network types.""" +class Network(ABC): # noqa: B024 + """Base network object defining communication logic shared by all network types.""" def __init__( self, @@ -36,29 +35,9 @@ def __init__( @property def graph(self) -> AgentGraph: - """Underlying agent graph.""" + """Underlying NetworkX graph; mutating it will change the network.""" return self._graph - @property - def message_noise(self) -> NoiseScheme: - """Noise scheme applied to messages.""" - return self._message_noise - - @property - def message_compression(self) -> CompressionScheme: - """Compression scheme applied to messages.""" - return self._message_compression - - @property - def message_drop(self) -> DropScheme: - """Drop scheme applied to messages.""" - return self._message_drop - - @abstractmethod - def kind(self) -> type["Network"]: - """Concrete network class for type-based dispatch.""" - ... 
- def agents(self) -> list[Agent]: """Get all agents in the network.""" return list(self.graph) @@ -72,9 +51,13 @@ def active_agents(self, iteration: int) -> list[Agent]: """ return [a for a in self.agents() if a._activation.is_active(iteration)] # noqa: SLF001 - def send(self, sender: Agent, receiver: Agent, msg: Array) -> None: + def _adjacent_agents(self, agent: Agent) -> list[Agent]: + """Agents adjacent to ``agent`` in the underlying graph.""" + return list(self.graph.neighbors(agent)) + + def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: """ - Send message to a neighbor. + Send message to an agent. The message may be compressed, distorted by noise, and/or dropped depending on the network's :class:`~decent_bench.schemes.CompressionScheme`, @@ -85,16 +68,47 @@ def send(self, sender: Agent, receiver: Agent, msg: Array) -> None: same receiver. After being received or replaced, the message is destroyed. """ sender._n_sent_messages += 1 # noqa: SLF001 - if self.message_drop.should_drop(): + if self._message_drop.should_drop(): sender._n_sent_messages_dropped += 1 # noqa: SLF001 return - msg = self.message_compression.compress(msg) - msg = self.message_noise.make_noise(msg) + msg = self._message_compression.compress(msg) + msg = self._message_noise.make_noise(msg) self.graph.edges[sender, receiver][str(receiver.id)] = msg - def receive(self, receiver: Agent, sender: Agent) -> None: + def send( + self, + sender: Agent, + receiver: Agent | Iterable[Agent] | None = None, + msg: Array | None = None, + ) -> None: """ - Receive message from a neighbor. + Send message to one or more agents. + + Args: + sender: sender agent + receiver: receiver agent, iterable of receiver agents, or ``None`` to broadcast to adjacent agents. + msg: message to send + + Raises: + ValueError: if ``msg`` is not provided. 
+ + """ + if msg is None: + raise ValueError("msg must be provided") + + if receiver is None: + receivers = self._adjacent_agents(sender) + elif isinstance(receiver, Agent): + receivers = [receiver] + else: + receivers = list(receiver) + + for r in receivers: + self._send_one(sender=sender, receiver=r, msg=msg) + + def _receive_one(self, receiver: Agent, sender: Agent) -> None: + """ + Receive message from an agent. Received messages are stored in :attr:`Agent.messages `. @@ -105,9 +119,28 @@ def receive(self, receiver: Agent, sender: Agent) -> None: receiver._received_messages[sender] = msg # noqa: SLF001 self.graph.edges[sender, receiver][str(receiver.id)] = None + def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None) -> None: + """ + Receive message(s) at an agent. + + Args: + receiver: receiver agent + sender: sender agent, iterable of sender agents, or ``None`` to receive from all adjacent agents. + + """ + if sender is None: + senders = self._adjacent_agents(receiver) + elif isinstance(sender, Agent): + senders = [sender] + else: + senders = list(sender) + + for s in senders: + self._receive_one(receiver=receiver, sender=s) + class P2PNetwork(Network): - """Peer-to-peer network of agents that communicate by sending and receiving messages.""" + """Peer-to-peer network architecture where agents communicate directly with each other.""" def __init__( self, @@ -124,10 +157,6 @@ def __init__( ) self.W: Array | None = None - def kind(self) -> type["Network"]: - """Concrete network class for type-based dispatch.""" - return type(self) - def set_weights(self, weights: Array) -> None: """ Set custom consensus weights matrix. @@ -191,30 +220,16 @@ def neighbors(self, agent: Agent) -> list[Agent]: """Get all neighbors of an agent.""" return list(self.graph[agent]) - def broadcast(self, sender: Agent, msg: Array) -> None: - """ - Send message to all neighbors. 
+ def _adjacent_agents(self, agent: Agent) -> list[Agent]: + return self.neighbors(agent) - The message may be compressed, distorted by noise, and/or dropped depending on the network's - :class:`~decent_bench.schemes.CompressionScheme`, - :class:`~decent_bench.schemes.NoiseScheme`, - and :class:`~decent_bench.schemes.DropScheme`. - - The message will stay in-flight until it is received or replaced by a newer message from the same sender to the - same receiver. After being received or replaced, the message is destroyed. - """ - for neighbor in self.neighbors(sender): - self.send(sender=sender, receiver=neighbor, msg=msg) + def broadcast(self, sender: Agent, msg: Array) -> None: + """Send to all neighbors (alias for :meth:`~decent_bench.networks.Network.send` with ``receiver=None``).""" + self.send(sender=sender, receiver=None, msg=msg) def receive_all(self, receiver: Agent) -> None: - """ - Receive messages from all neighbors. - - Received messages are stored in - :attr:`Agent.messages `. 
- """ - for neighbor in self.neighbors(receiver): - self.receive(receiver, neighbor) + """Receive from all neighbors (alias for Network.receive with sender=None).""" + self.receive(receiver=receiver, sender=None) class FedNetwork(Network): @@ -239,33 +254,133 @@ def _identify_server(self) -> Agent: degrees = dict(self.graph.degree()) if not degrees: raise ValueError("FedNetwork requires at least one agent") - server, max_degree = max(degrees.items(), key=itemgetter(1)) + server, max_degree = max(degrees.items(), key=lambda item: item[1]) # noqa: FURB118 n = len(degrees) if max_degree != n - 1 or any(deg != 1 for node, deg in degrees.items() if node != server): raise ValueError("FedNetwork expects a star topology with one server connected to all clients") return server - def kind(self) -> type["Network"]: - """Concrete network class for type-based dispatch.""" - return type(self) - @property def server(self) -> Agent: """Agent acting as the central server.""" return self._server @property - def clients(self) -> list[Agent]: - """Agents acting as clients.""" + def coordinator(self) -> Agent: + """Alias for :attr:`server`.""" + return self.server + + def agents(self) -> list[Agent]: + """Get all client agents (excludes the server/coordinator).""" return [agent for agent in self.graph if agent is not self.server] + def active_agents(self, iteration: int) -> list[Agent]: + """Get all active client agents (excludes the server/coordinator).""" + # Delegates to Network.active_agents(), which iterates over self.agents() (clients only for FedNetwork). 
+ return super().active_agents(iteration) + + @property + def clients(self) -> list[Agent]: + """Alias for :meth:`agents`.""" + return self.agents() + def active_clients(self, iteration: int) -> list[Agent]: + """Alias for :meth:`active_agents`.""" + return self.active_agents(iteration) + + def send( + self, + sender: Agent, + receiver: Agent | Iterable[Agent] | None = None, + msg: Array | None = None, + ) -> None: """ - Get all active clients (excludes the server). + Send message(s) in a federated learning network. + + Only server <-> client communication is allowed. Client-to-client and server-to-server communication will + raise an error. + + Raises: + ValueError: if msg is missing or if sender/receiver roles are invalid. - Uses :meth:`Network.active_agents` to honor activation schemes and then filters out the server. """ - return [agent for agent in self.active_agents(iteration) if agent is not self.server] + if msg is None: + raise ValueError("msg must be provided") + + if sender not in self.graph: + raise ValueError("Sender must be an agent in the network") + + if receiver is None: + if sender is self.server: + super().send(sender=sender, receiver=self.clients, msg=msg) # server -> clients + return + super().send(sender=sender, receiver=self.server, msg=msg) # client -> server + return + + if isinstance(receiver, Agent): + if receiver not in self.graph: + raise ValueError("Receiver must be an agent in the network") + + if sender is self.server: + if receiver is self.server: + raise ValueError("Server-to-server communication is not supported") + super().send(sender=sender, receiver=receiver, msg=msg) # server -> client + return + + if receiver is not self.server: + raise ValueError("Client-to-client communication is not supported") + super().send(sender=sender, receiver=receiver, msg=msg) # client -> server + return + + receivers = list(receiver) + if sender is not self.server: + raise ValueError("Only the server can send to multiple receivers") + if any(r not 
in self.graph or r is self.server for r in receivers): + raise ValueError("All receivers must be clients") + super().send(sender=sender, receiver=receivers, msg=msg) + + def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None) -> None: + """ + Receive message(s) in a federated learning network. + + Only server <-> client communication is allowed. Client-to-client and server-to-server communication will + raise an error. + + Raises: + ValueError: if sender/receiver roles are invalid. + + """ + if receiver not in self.graph: + raise ValueError("Receiver must be an agent in the network") + + if sender is None: + if receiver is self.server: + super().receive(receiver=receiver, sender=self.clients) + return + super().receive(receiver=receiver, sender=self.server) + return + + if isinstance(sender, Agent): + if sender not in self.graph: + raise ValueError("Sender must be an agent in the network") + + if receiver is self.server: + if sender is self.server: + raise ValueError("Server-to-server communication is not supported") + super().receive(receiver=receiver, sender=sender) # server receives from a client + return + + if sender is not self.server: + raise ValueError("Client-to-client communication is not supported") + super().receive(receiver=receiver, sender=sender) # client receives from the server + return + + senders = list(sender) + if receiver is not self.server: + raise ValueError("Only the server can receive from multiple senders") + if any(s not in self.graph or s is self.server for s in senders): + raise ValueError("All senders must be clients") + super().receive(receiver=receiver, sender=senders) def send_to_client(self, client: Agent, msg: Array) -> None: """ @@ -281,8 +396,7 @@ def send_to_client(self, client: Agent, msg: Array) -> None: def send_to_all_clients(self, msg: Array) -> None: """Send the same message from the server to every client (synchronous FL push).""" - for client in self.clients: - self.send_to_client(client, msg) 
+ self.send(sender=self.server, receiver=None, msg=msg) def send_from_client(self, client: Agent, msg: Array) -> None: """ @@ -328,7 +442,7 @@ def receive_at_client(self, client: Agent) -> None: """ if client not in self.clients: raise ValueError("Receiver must be a client") - self.receive(receiver=client, sender=self.server) + self.receive(receiver=client, sender=None) # or self.receive(receiver=client, sender=self.server) def receive_at_all_clients(self) -> None: """Receive messages at every client from the server (synchronous FL pull).""" @@ -349,8 +463,7 @@ def receive_from_client(self, client: Agent) -> None: def receive_from_all_clients(self) -> None: """Receive messages at the server from every client (synchronous FL pull).""" - for client in self.clients: - self.receive_from_client(client) + self.receive(receiver=self.server, sender=None) def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: @@ -405,7 +518,7 @@ def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: raise NotImplementedError("Support for disconnected graphs has not been implemented yet") degrees = dict(problem.network_structure.degree()) if n_agents: - server, max_degree = max(degrees.items(), key=itemgetter(1)) + server, max_degree = max(degrees.items(), key=lambda item: item[1]) # noqa: FURB118 if max_degree != n_agents - 1 or any(deg != 1 for node, deg in degrees.items() if node != server): raise ValueError("Federated network requires a star topology (one server connected to all clients)") agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] diff --git a/docs/source/user.rst b/docs/source/user.rst index 3a6b215..9b50189 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -304,6 +304,8 @@ variable **once** every iteration. 
If you need to perform multiple updates withi Similarly, in order for the benchmark problem's communication schemes to be applied, use the :attr:`~decent_bench.networks.P2PNetwork`/ :attr:`~decent_bench.networks.FedNetwork` object to retrieve agents and to send and receive messages. Be sure to use :meth:`~decent_bench.networks.Network.active_agents` during algorithm runtime so that asynchrony is properly handled. +You can also inspect :attr:`~decent_bench.networks.Network.graph` to use NetworkX utilities (e.g., plotting or listing edges); mutating this graph changes the network topology. +In :class:`~decent_bench.networks.FedNetwork`, :meth:`~decent_bench.networks.Network.agents` and :meth:`~decent_bench.networks.Network.active_agents` refer to clients (the server is available via :attr:`~decent_bench.networks.FedNetwork.server`/ :attr:`~decent_bench.networks.FedNetwork.coordinator`). .. code-block:: python From 38075db35f0bf07a96f5b0415bbc396e29033300 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 19 Dec 2025 01:53:02 +0100 Subject: [PATCH 06/16] enh(networks): address review feedback (#229) Iterate receiver iterable directly in Network.send/receive --- decent_bench/networks.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 62c4db0..89d5402 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -96,12 +96,13 @@ def send( if msg is None: raise ValueError("msg must be provided") + receivers: Iterable[Agent] if receiver is None: receivers = self._adjacent_agents(sender) elif isinstance(receiver, Agent): receivers = [receiver] else: - receivers = list(receiver) + receivers = receiver for r in receivers: self._send_one(sender=sender, receiver=r, msg=msg) @@ -128,12 +129,13 @@ def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None sender: sender agent, iterable of sender agents, or ``None`` to receive from all adjacent agents. 
""" + senders: Iterable[Agent] if sender is None: senders = self._adjacent_agents(receiver) elif isinstance(sender, Agent): senders = [sender] else: - senders = list(sender) + senders = sender for s in senders: self._receive_one(receiver=receiver, sender=s) @@ -442,7 +444,7 @@ def receive_at_client(self, client: Agent) -> None: """ if client not in self.clients: raise ValueError("Receiver must be a client") - self.receive(receiver=client, sender=None) # or self.receive(receiver=client, sender=self.server) + self.receive(receiver=client, sender=None) def receive_at_all_clients(self) -> None: """Receive messages at every client from the server (synchronous FL pull).""" From 0b2acca9685af868775cc40bf6d6b3db63dfc66f Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 19 Dec 2025 18:38:00 +0100 Subject: [PATCH 07/16] ref(networks): Enforce connectivity checks in send/receive (#229) Centralize sender/receiver connectivity validation in Network.send/receive and expose connected_agents for all network types. Simplify FedNetwork send/receive to rely on the base checks while keeping FL-specific errors and fan-out guards. Remove the duplicated output section in user.rst. 
--- decent_bench/networks.py | 114 +++++++++++++++++++-------------------- docs/source/user.rst | 10 ---- 2 files changed, 55 insertions(+), 69 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 89d5402..f5ede2f 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -38,6 +38,11 @@ def graph(self) -> AgentGraph: """Underlying NetworkX graph; mutating it will change the network.""" return self._graph + @property + def G(self) -> AgentGraph: # noqa: N802 + """Alias for the underlying graph.""" + return self.graph + def agents(self) -> list[Agent]: """Get all agents in the network.""" return list(self.graph) @@ -51,8 +56,8 @@ def active_agents(self, iteration: int) -> list[Agent]: """ return [a for a in self.agents() if a._activation.is_active(iteration)] # noqa: SLF001 - def _adjacent_agents(self, agent: Agent) -> list[Agent]: - """Agents adjacent to ``agent`` in the underlying graph.""" + def connected_agents(self, agent: Agent) -> list[Agent]: + """Agents directly connected to ``agent`` in the underlying graph.""" return list(self.graph.neighbors(agent)) def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: @@ -86,24 +91,32 @@ def send( Args: sender: sender agent - receiver: receiver agent, iterable of receiver agents, or ``None`` to broadcast to adjacent agents. + receiver: receiver agent, iterable of receiver agents, or ``None`` to broadcast to connected agents. msg: message to send Raises: - ValueError: if ``msg`` is not provided. + ValueError: if ``msg`` is not provided, if agents are not part of the network, or if sender/receiver are not + connected. 
""" if msg is None: raise ValueError("msg must be provided") - receivers: Iterable[Agent] + if sender not in self.graph: + raise ValueError("Sender must be an agent in the network") + + receivers: Iterable[Agent] | list[Agent] if receiver is None: - receivers = self._adjacent_agents(sender) + receivers = self.connected_agents(sender) elif isinstance(receiver, Agent): receivers = [receiver] else: receivers = receiver + receivers = list(receivers) + if any(r not in self.connected_agents(sender) for r in receivers): + raise ValueError("Sender and receiver must be connected in the network") + for r in receivers: self._send_one(sender=sender, receiver=r, msg=msg) @@ -126,17 +139,27 @@ def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None Args: receiver: receiver agent - sender: sender agent, iterable of sender agents, or ``None`` to receive from all adjacent agents. + sender: sender agent, iterable of sender agents, or ``None`` to receive from all connected agents. + + Raises: + ValueError: if sender/receiver are not part of the network or not connected. 
""" - senders: Iterable[Agent] + if receiver not in self.graph: + raise ValueError("Receiver must be an agent in the network") + + senders: Iterable[Agent] | list[Agent] if sender is None: - senders = self._adjacent_agents(receiver) + senders = self.connected_agents(receiver) elif isinstance(sender, Agent): senders = [sender] else: senders = sender + senders = list(senders) + if any(s not in self.connected_agents(receiver) for s in senders): + raise ValueError("Sender and receiver must be connected in the network") + for s in senders: self._receive_one(receiver=receiver, sender=s) @@ -222,7 +245,8 @@ def neighbors(self, agent: Agent) -> list[Agent]: """Get all neighbors of an agent.""" return list(self.graph[agent]) - def _adjacent_agents(self, agent: Agent) -> list[Agent]: + def connected_agents(self, agent: Agent) -> list[Agent]: + """Agents directly connected to ``agent`` (alias for :meth:`neighbors`).""" return self.neighbors(agent) def broadcast(self, sender: Agent, msg: Array) -> None: @@ -303,41 +327,26 @@ def send( raise an error. Raises: - ValueError: if msg is missing or if sender/receiver roles are invalid. + ValueError: if server-to-server or client-to-client communication is attempted, or if a non-server tries to + send to multiple receivers. Also see :meth:`Network.send` for generic validation. 
""" - if msg is None: - raise ValueError("msg must be provided") - - if sender not in self.graph: - raise ValueError("Sender must be an agent in the network") - - if receiver is None: - if sender is self.server: - super().send(sender=sender, receiver=self.clients, msg=msg) # server -> clients - return - super().send(sender=sender, receiver=self.server, msg=msg) # client -> server - return - if isinstance(receiver, Agent): - if receiver not in self.graph: - raise ValueError("Receiver must be an agent in the network") - - if sender is self.server: - if receiver is self.server: - raise ValueError("Server-to-server communication is not supported") - super().send(sender=sender, receiver=receiver, msg=msg) # server -> client - return - - if receiver is not self.server: + if sender is self.server and receiver is self.server: + raise ValueError("Server-to-server communication is not supported") + if sender is not self.server and receiver is not self.server: raise ValueError("Client-to-client communication is not supported") - super().send(sender=sender, receiver=receiver, msg=msg) # client -> server + super().send(sender=sender, receiver=receiver, msg=msg) + return + + if receiver is None: + super().send(sender=sender, receiver=receiver, msg=msg) return receivers = list(receiver) if sender is not self.server: raise ValueError("Only the server can send to multiple receivers") - if any(r not in self.graph or r is self.server for r in receivers): + if any(r is self.server for r in receivers): raise ValueError("All receivers must be clients") super().send(sender=sender, receiver=receivers, msg=msg) @@ -349,38 +358,25 @@ def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None raise an error. Raises: - ValueError: if sender/receiver roles are invalid. + ValueError: if sender/receiver roles are invalid. Also see :meth:`Network.receive` for generic validation. 
""" - if receiver not in self.graph: - raise ValueError("Receiver must be an agent in the network") - - if sender is None: - if receiver is self.server: - super().receive(receiver=receiver, sender=self.clients) - return - super().receive(receiver=receiver, sender=self.server) - return - if isinstance(sender, Agent): - if sender not in self.graph: - raise ValueError("Sender must be an agent in the network") - - if receiver is self.server: - if sender is self.server: - raise ValueError("Server-to-server communication is not supported") - super().receive(receiver=receiver, sender=sender) # server receives from a client - return - - if sender is not self.server: + if receiver is self.server and sender is self.server: + raise ValueError("Server-to-server communication is not supported") + if receiver is not self.server and sender is not self.server: raise ValueError("Client-to-client communication is not supported") - super().receive(receiver=receiver, sender=sender) # client receives from the server + super().receive(receiver=receiver, sender=sender) + return + + if sender is None: + super().receive(receiver=receiver, sender=sender) return senders = list(sender) if receiver is not self.server: raise ValueError("Only the server can receive from multiple senders") - if any(s not in self.graph or s is self.server for s in senders): + if any(s is self.server for s in senders): raise ValueError("All senders must be clients") super().receive(receiver=receiver, sender=senders) diff --git a/docs/source/user.rst b/docs/source/user.rst index 9b50189..d932556 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -43,16 +43,6 @@ Benchmark executions will have outputs like these: :align: center -Benchmark executions will have outputs like these: - -.. list-table:: - - * - .. image:: _static/table.png - :align: center - - .. 
image:: _static/plot.png - :align: center - - Execution settings ------------------ Configure settings for metrics, trials, statistical confidence level, logging, and multiprocessing. From 2e596022bd534b2ffa8644cd874662a005c12958 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 19 Dec 2025 19:20:46 +0100 Subject: [PATCH 08/16] enh(networks): Alias neighbors to base connected_agents (#229) Delegate P2PNetwork.neighbors to Network.connected_agents --- decent_bench/networks.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index f5ede2f..c43db1b 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -242,12 +242,8 @@ def adjacency(self) -> Array: return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device) def neighbors(self, agent: Agent) -> list[Agent]: - """Get all neighbors of an agent.""" - return list(self.graph[agent]) - - def connected_agents(self, agent: Agent) -> list[Agent]: - """Agents directly connected to ``agent`` (alias for :meth:`neighbors`).""" - return self.neighbors(agent) + """Alias for :meth:`~decent_bench.networks.Network.connected_agents`.""" + return super().connected_agents(agent) def broadcast(self, sender: Agent, msg: Array) -> None: """Send to all neighbors (alias for :meth:`~decent_bench.networks.Network.send` with ``receiver=None``).""" From 3ff5876f100a9edcea9c2087b82380e04913aa9e Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 19 Dec 2025 21:50:00 +0100 Subject: [PATCH 09/16] feat(networks): expose networkx helpers and plotting Add degrees/edges accessors and plotting helper to networks. Use networkx to build adjacency for P2P networks and allow auto-plot via benchmark creation flags. Document plotting options in user guide. 
closes #206 --- decent_bench/benchmark_problem.py | 10 +++++ decent_bench/networks.py | 74 ++++++++++++++++++++++++++----- docs/source/user.rst | 5 +++ 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/decent_bench/benchmark_problem.py b/decent_bench/benchmark_problem.py index ad7377a..23f4bb7 100644 --- a/decent_bench/benchmark_problem.py +++ b/decent_bench/benchmark_problem.py @@ -45,6 +45,8 @@ class BenchmarkProblem: message_compression: message compression setting message_noise: message noise setting message_drop: message drops setting + plot_network: plot the network when it is created (optional) + plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` """ @@ -55,6 +57,8 @@ class BenchmarkProblem: message_compression: CompressionScheme message_noise: NoiseScheme message_drop: DropScheme + plot_network: bool = False + plot_network_kwargs: dict[str, Any] | None = None def create_regression_problem( @@ -66,6 +70,8 @@ def create_regression_problem( compression: bool = False, noise: bool = False, drops: bool = False, + plot_network: bool = False, + plot_network_kwargs: dict[str, Any] | None = None, ) -> BenchmarkProblem: """ Create out-of-the-box regression problems. 
@@ -78,6 +84,8 @@ def create_regression_problem( compression: if true, messages are rounded to 4 significant digits noise: if true, messages are distorted by Gaussian noise drops: if true, messages have a 50% probability of being dropped + plot_network: if true, plot the network when it is created + plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` when ``plot_network`` is true """ network_structure = nx.random_regular_graph(n_neighbors_per_agent, n_agents, seed=0) @@ -105,4 +113,6 @@ def create_regression_problem( message_compression=message_compression, message_noise=message_noise, message_drop=message_drop, + plot_network=plot_network, + plot_network_kwargs=plot_network_kwargs, ) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index c43db1b..fad6c0a 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,7 +1,7 @@ from abc import ABC from collections.abc import Iterable, Mapping from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import networkx as nx import numpy as np @@ -17,6 +17,14 @@ else: AgentGraph = nx.Graph +_LAYOUT_FUNCS: dict[str, Any] = { + "spring": nx.spring_layout, + "kamada_kawai": nx.kamada_kawai_layout, + "circular": nx.circular_layout, + "random": nx.random_layout, + "shell": nx.shell_layout, +} + class Network(ABC): # noqa: B024 """Base network object defining communication logic shared by all network types.""" @@ -47,6 +55,16 @@ def agents(self) -> list[Agent]: """Get all agents in the network.""" return list(self.graph) + @property + def degrees(self) -> dict[Agent, int]: + """Degree of each agent in the network.""" + return dict(self.graph.degree()) + + @property + def edges(self) -> list[tuple[Agent, Agent]]: + """Edges of the network as (agent, agent) tuples.""" + return list(self.graph.edges()) + def active_agents(self, iteration: int) -> list[Agent]: """ Get all active agents. 
@@ -60,6 +78,35 @@ def connected_agents(self, agent: Agent) -> list[Agent]: """Agents directly connected to ``agent`` in the underlying graph.""" return list(self.graph.neighbors(agent)) + def plot(self, ax: Any = None, layout: str = "spring", **draw_kwargs: Any) -> Any: + """ + Plot the network using NetworkX drawing utilities. + + Args: + ax: optional matplotlib Axes to draw on. If ``None`` a new figure is created. + layout: layout algorithm to position nodes (e.g. ``spring``, ``kamada_kawai``, ``circular``, ``random``, ``shell``). + draw_kwargs: forwarded to :func:`networkx.draw_networkx`. + + Returns: + The matplotlib Axes containing the plot. + """ + try: + import matplotlib.pyplot as plt + except Exception as exc: # pragma: no cover - runtime dependency guard + raise RuntimeError("matplotlib is required for plotting the network") from exc + + layout_func = _LAYOUT_FUNCS.get(layout) + if layout_func is None: + supported = ", ".join(sorted(_LAYOUT_FUNCS)) + raise ValueError(f"Unsupported layout '{layout}'. Supported layouts: {supported}") + + pos = layout_func(self.graph) + if ax is None: + _, ax = plt.subplots() + + nx.draw_networkx(self.graph, pos=pos, ax=ax, **draw_kwargs) + return ax + def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: """ Send message to an agent. @@ -233,13 +280,12 @@ def adjacency(self) -> Array: Use ``adjacency[i, j]`` or ``adjacency[i.id, j.id]`` to get the adjacency between agent i and j. 
""" agents = self.agents() - n = len(agents) - A = np.zeros((n, n)) # noqa: N806 - for i in agents: - for j in self.neighbors(i): - A[i, j] = 1 - - return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device) + adjacency_matrix = nx.to_numpy_array(self.graph, nodelist=agents, dtype=float) + return iop.to_array( + adjacency_matrix, + agents[0].cost.framework, + agents[0].cost.device, + ) def neighbors(self, agent: Agent) -> list[Agent]: """Alias for :meth:`~decent_bench.networks.Network.connected_agents`.""" @@ -482,12 +528,16 @@ def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) - return P2PNetwork( + network = P2PNetwork( graph=graph, message_noise=problem.message_noise, message_compression=problem.message_compression, message_drop=problem.message_drop, ) + if getattr(problem, "plot_network", False): + plot_kwargs = getattr(problem, "plot_network_kwargs", None) or {} + network.plot(**plot_kwargs) + return network def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: @@ -518,9 +568,13 @@ def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) - return FedNetwork( + network = FedNetwork( graph=graph, message_noise=problem.message_noise, message_compression=problem.message_compression, message_drop=problem.message_drop, ) + if getattr(problem, "plot_network", False): + plot_kwargs = getattr(problem, "plot_network_kwargs", None) or {} + network.plot(**plot_kwargs) + return network diff --git 
a/docs/source/user.rst b/docs/source/user.rst index d932556..ef0e8d6 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -94,6 +94,9 @@ Configure communication constraints and other settings for out-of-the-box regres compression=True, noise=True, drops=True, + # Optional: plot the network when it is created + plot_network=False, + plot_network_kwargs=None, ) if __name__ == "__main__": @@ -130,6 +133,8 @@ Change the settings of an already created benchmark problem, for example, the ne compression=True, noise=True, drops=True, + plot_network=True, + plot_network_kwargs={"layout": "circular", "with_labels": True}, ) problem.network_structure = nx.random_regular_graph(n_agents, n_neighbors_per_agent) From 8efe1899be691184508d4e8c556761586b7b32a7 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Sat, 20 Dec 2025 01:04:05 +0100 Subject: [PATCH 10/16] enh(networks): Detail connection validation errors (#229) Add sequence-based send/receive handling while preserving behavior. Report invalid agent ids in connection errors. --- decent_bench/networks.py | 66 ++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index c43db1b..4eb780b 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,5 +1,5 @@ from abc import ABC -from collections.abc import Iterable, Mapping +from collections.abc import Mapping, Sequence from functools import cached_property from typing import TYPE_CHECKING @@ -83,7 +83,7 @@ def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: def send( self, sender: Agent, - receiver: Agent | Iterable[Agent] | None = None, + receiver: Agent | Sequence[Agent] | None = None, msg: Array | None = None, ) -> None: """ @@ -91,7 +91,7 @@ def send( Args: sender: sender agent - receiver: receiver agent, iterable of receiver agents, or ``None`` to broadcast to connected agents. 
+ receiver: receiver agent, sequence of receiver agents, or ``None`` to broadcast to connected agents. msg: message to send Raises: @@ -105,19 +105,20 @@ def send( if sender not in self.graph: raise ValueError("Sender must be an agent in the network") - receivers: Iterable[Agent] | list[Agent] if receiver is None: - receivers = self.connected_agents(sender) + receiver = self.connected_agents(sender) elif isinstance(receiver, Agent): - receivers = [receiver] - else: - receivers = receiver - - receivers = list(receivers) - if any(r not in self.connected_agents(sender) for r in receivers): - raise ValueError("Sender and receiver must be connected in the network") + if receiver not in self.connected_agents(sender): + raise ValueError("Sender and receiver must be connected in the network") + self._send_one(sender=sender, receiver=receiver, msg=msg) + return + neighbors = set(self.connected_agents(sender)) + invalid_receivers = [r for r in receiver if r not in neighbors] + if invalid_receivers: + ids = [r.id for r in invalid_receivers] + raise ValueError(f"Sender and receiver must be connected in the network; not connected receivers: {ids}") - for r in receivers: + for r in receiver: self._send_one(sender=sender, receiver=r, msg=msg) def _receive_one(self, receiver: Agent, sender: Agent) -> None: @@ -133,13 +134,13 @@ def _receive_one(self, receiver: Agent, sender: Agent) -> None: receiver._received_messages[sender] = msg # noqa: SLF001 self.graph.edges[sender, receiver][str(receiver.id)] = None - def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None) -> None: + def receive(self, receiver: Agent, sender: Agent | Sequence[Agent] | None = None) -> None: """ Receive message(s) at an agent. Args: receiver: receiver agent - sender: sender agent, iterable of sender agents, or ``None`` to receive from all connected agents. + sender: sender agent, sequence of sender agents, or ``None`` to receive from all connected agents. 
Raises: ValueError: if sender/receiver are not part of the network or not connected. @@ -148,19 +149,20 @@ def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None if receiver not in self.graph: raise ValueError("Receiver must be an agent in the network") - senders: Iterable[Agent] | list[Agent] if sender is None: - senders = self.connected_agents(receiver) + sender = self.connected_agents(receiver) elif isinstance(sender, Agent): - senders = [sender] - else: - senders = sender - - senders = list(senders) - if any(s not in self.connected_agents(receiver) for s in senders): - raise ValueError("Sender and receiver must be connected in the network") + if sender not in self.connected_agents(receiver): + raise ValueError("Sender and receiver must be connected in the network") + self._receive_one(receiver=receiver, sender=sender) + return + neighbors = set(self.connected_agents(receiver)) + invalid_senders = [s for s in sender if s not in neighbors] + if invalid_senders: + ids = [s.id for s in invalid_senders] + raise ValueError(f"Sender and receiver must be connected in the network; not connected senders: {ids}") - for s in senders: + for s in sender: self._receive_one(receiver=receiver, sender=s) @@ -313,7 +315,7 @@ def active_clients(self, iteration: int) -> list[Agent]: def send( self, sender: Agent, - receiver: Agent | Iterable[Agent] | None = None, + receiver: Agent | Sequence[Agent] | None = None, msg: Array | None = None, ) -> None: """ @@ -339,14 +341,13 @@ def send( super().send(sender=sender, receiver=receiver, msg=msg) return - receivers = list(receiver) if sender is not self.server: raise ValueError("Only the server can send to multiple receivers") - if any(r is self.server for r in receivers): + if any(r is self.server for r in receiver): raise ValueError("All receivers must be clients") - super().send(sender=sender, receiver=receivers, msg=msg) + super().send(sender=sender, receiver=receiver, msg=msg) - def receive(self, receiver: 
Agent, sender: Agent | Iterable[Agent] | None = None) -> None: + def receive(self, receiver: Agent, sender: Agent | Sequence[Agent] | None = None) -> None: """ Receive message(s) in a federated learning network. @@ -369,12 +370,11 @@ def receive(self, receiver: Agent, sender: Agent | Iterable[Agent] | None = None super().receive(receiver=receiver, sender=sender) return - senders = list(sender) if receiver is not self.server: raise ValueError("Only the server can receive from multiple senders") - if any(s is self.server for s in senders): + if any(s is self.server for s in sender): raise ValueError("All senders must be clients") - super().receive(receiver=receiver, sender=senders) + super().receive(receiver=receiver, sender=sender) def send_to_client(self, client: Agent, msg: Array) -> None: """ From e9880e22069149095a5ea0ce766b221f893db7b6 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Sat, 20 Dec 2025 02:54:02 +0100 Subject: [PATCH 11/16] fix(networks): align plotting typing and docs config Adjust Network.plot to import matplotlib at call time and accept typed kwargs copied into the draw call. Cast adjacency conversion for agent nodes. Update docstring wrapping and Sphinx intersphinx/Axes alias to keep ruff, mypy, and sphinx clean. 
--- decent_bench/benchmark_problem.py | 3 ++- decent_bench/networks.py | 36 ++++++++++++++++++++++++------- docs/source/conf.py | 6 +++++- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/decent_bench/benchmark_problem.py b/decent_bench/benchmark_problem.py index 23f4bb7..d372c75 100644 --- a/decent_bench/benchmark_problem.py +++ b/decent_bench/benchmark_problem.py @@ -85,7 +85,8 @@ def create_regression_problem( noise: if true, messages are distorted by Gaussian noise drops: if true, messages have a 50% probability of being dropped plot_network: if true, plot the network when it is created - plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` when ``plot_network`` is true + plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` when + ``plot_network`` is true """ network_structure = nx.random_regular_graph(n_neighbors_per_agent, n_agents, seed=0) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index fad6c0a..f6846d3 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from abc import ABC -from collections.abc import Iterable, Mapping +from collections.abc import Collection, Iterable, Mapping from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import networkx as nx import numpy as np @@ -13,6 +15,8 @@ from decent_bench.utils.array import Array if TYPE_CHECKING: + from matplotlib.axes import Axes + AgentGraph = nx.Graph[Agent] else: AgentGraph = nx.Graph @@ -78,20 +82,26 @@ def connected_agents(self, agent: Agent) -> list[Agent]: """Agents directly connected to ``agent`` in the underlying graph.""" return list(self.graph.neighbors(agent)) - def plot(self, ax: Any = None, layout: str = "spring", **draw_kwargs: Any) -> Any: + def plot(self, ax: Axes | None = None, layout: str = "spring", **draw_kwargs: Mapping[str, object]) -> Axes: """ 
Plot the network using NetworkX drawing utilities. Args: ax: optional matplotlib Axes to draw on. If ``None`` a new figure is created. - layout: layout algorithm to position nodes (e.g. ``spring``, ``kamada_kawai``, ``circular``, ``random``, ``shell``). - draw_kwargs: forwarded to :func:`networkx.draw_networkx`. + layout: layout algorithm to position nodes (e.g. ``spring``, ``kamada_kawai``, ``circular``, + ``random``, ``shell``). + draw_kwargs: forwarded to ``networkx.draw_networkx``. Returns: The matplotlib Axes containing the plot. + + Raises: + RuntimeError: if matplotlib is not available. + ValueError: if an unsupported layout is requested. + """ try: - import matplotlib.pyplot as plt + import matplotlib.pyplot as plt # noqa: PLC0415 except Exception as exc: # pragma: no cover - runtime dependency guard raise RuntimeError("matplotlib is required for plotting the network") from exc @@ -104,7 +114,13 @@ def plot(self, ax: Any = None, layout: str = "spring", **draw_kwargs: Any) -> An if ax is None: _, ax = plt.subplots() - nx.draw_networkx(self.graph, pos=pos, ax=ax, **draw_kwargs) + draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs) + nx.draw_networkx( + self.graph, + pos=pos, + ax=ax, + **draw_kwargs_dict, + ) return ax def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: @@ -280,7 +296,11 @@ def adjacency(self) -> Array: Use ``adjacency[i, j]`` or ``adjacency[i.id, j.id]`` to get the adjacency between agent i and j. 
""" agents = self.agents() - adjacency_matrix = nx.to_numpy_array(self.graph, nodelist=agents, dtype=float) + adjacency_matrix = nx.to_numpy_array( + cast("nx.Graph[Any]", self.graph), + nodelist=cast("Collection[Any]", agents), + dtype=float, + ) # type: ignore[call-overload] return iop.to_array( adjacency_matrix, agents[0].cost.framework, diff --git a/docs/source/conf.py b/docs/source/conf.py index 13368d3..7508859 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,7 +30,9 @@ autodoc_default_options = {} autodoc_member_order = "bysource" autodoc_preserve_defaults = True -autodoc_type_aliases = {} +autodoc_type_aliases = { + "Axes": "matplotlib.axes.Axes", +} nitpicky = True nitpick_ignore = [ @@ -42,6 +44,7 @@ ("py:class", "TorchTensor"), ("py:class", "TensorFlowTensor"), ("py:class", "JaxArray"), + ("py:class", "TypeAliasForwardRef"), ] intersphinx_mapping = { @@ -49,6 +52,7 @@ "numpy": ("https://numpy.org/doc/stable/", None), "python": ("https://docs.python.org/3", None), "torch": ("https://pytorch.org/docs/stable/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), "tensorflow": ( "https://www.tensorflow.org/api_docs/python", "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", From d581924369d2dd3b5a6a5501eca054a2aa44f2da Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Sat, 20 Dec 2025 10:22:34 +0100 Subject: [PATCH 12/16] docs(user): Document network plot kwargs and layouts Explain supported layouts and common draw kwargs passed via plot_network_kwargs and net.plot so users know which values to set. 
--- docs/source/user.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/user.rst b/docs/source/user.rst index ef0e8d6..e922089 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -150,6 +150,11 @@ Change the settings of an already created benchmark problem, for example, the ne ) +Notes on plotting +~~~~~~~~~~~~~~~~~ +``plot_network_kwargs`` are passed directly to ``networkx.draw_networkx``. Supported ``layout`` values are ``spring``, ``kamada_kawai``, ``circular``, ``random``, and ``shell``. Common kwargs include ``with_labels`` (``True``/``False``), ``labels`` (e.g. ``{agent: agent.id}``), ``node_color``, ``node_size``, and ``font_size``. Use ``plot_network=True`` to draw automatically at creation, or call ``net.plot(...)`` later with the same kwargs. + + Create problems using existing resources ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Create a custom benchmark problem using existing resources. From e5dcad2494b359077ef9810f9a5e18ce0dbe10d9 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Wed, 24 Dec 2025 11:47:37 +0100 Subject: [PATCH 13/16] wip(networkx-extras): Address PR review feedback --- decent_bench/networks.py | 23 +++++++++++++---------- docs/source/conf.py | 5 +---- docs/source/user.rst | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 0117973..1120bd0 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -3,10 +3,11 @@ from abc import ABC from collections.abc import Collection, Mapping, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast import networkx as nx import numpy as np +from matplotlib.axes import Axes import decent_bench.utils.interoperability as iop from decent_bench.agents import Agent @@ -15,13 +16,13 @@ from decent_bench.utils.array import Array if TYPE_CHECKING: - from matplotlib.axes import Axes - AgentGraph = 
nx.Graph[Agent] else: AgentGraph = nx.Graph -_LAYOUT_FUNCS: dict[str, Any] = { +Layout = Literal["spring", "kamada_kawai", "circular", "random", "shell"] + +_LAYOUT_FUNCS: dict[Layout, Any] = { "spring": nx.spring_layout, "kamada_kawai": nx.kamada_kawai_layout, "circular": nx.circular_layout, @@ -82,15 +83,16 @@ def connected_agents(self, agent: Agent) -> list[Agent]: """Agents directly connected to ``agent`` in the underlying graph.""" return list(self.graph.neighbors(agent)) - def plot(self, ax: Axes | None = None, layout: str = "spring", **draw_kwargs: Mapping[str, object]) -> Axes: + def plot(self, ax: Axes | None = None, layout: Layout = "spring", **draw_kwargs: Mapping[str, object]) -> Axes: """ Plot the network using NetworkX drawing utilities. Args: ax: optional matplotlib Axes to draw on. If ``None`` a new figure is created. - layout: layout algorithm to position nodes (e.g. ``spring``, ``kamada_kawai``, ``circular``, - ``random``, ``shell``). - draw_kwargs: forwarded to ``networkx.draw_networkx``. + layout: layout algorithm to position nodes (e.g. :func:`networkx.spring_layout`, + :func:`networkx.kamada_kawai_layout`, :func:`networkx.circular_layout`, + :func:`networkx.random_layout`, :func:`networkx.shell_layout`). + draw_kwargs: forwarded to :func:`networkx.draw_networkx`. Returns: The matplotlib Axes containing the plot. 
@@ -278,11 +280,12 @@ def weights(self) -> Array: n = len(agents) W = np.zeros((n, n)) # noqa: N806 + degrees = self.degrees for i in agents: neighbors = self.neighbors(i) - d_i = len(neighbors) + d_i = degrees[i] for j in neighbors: - d_j = len(self.neighbors(j)) + d_j = degrees[j] W[i, j] = 1 / (1 + max(d_i, d_j)) for i in agents: W[i, i] = 1 - sum(W[i]) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7508859..805bc62 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,9 +30,7 @@ autodoc_default_options = {} autodoc_member_order = "bysource" autodoc_preserve_defaults = True -autodoc_type_aliases = { - "Axes": "matplotlib.axes.Axes", -} +autodoc_type_aliases = {} nitpicky = True nitpick_ignore = [ @@ -44,7 +42,6 @@ ("py:class", "TorchTensor"), ("py:class", "TensorFlowTensor"), ("py:class", "JaxArray"), - ("py:class", "TypeAliasForwardRef"), ] intersphinx_mapping = { diff --git a/docs/source/user.rst b/docs/source/user.rst index e922089..3240b60 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -152,7 +152,7 @@ Change the settings of an already created benchmark problem, for example, the ne Notes on plotting ~~~~~~~~~~~~~~~~~ -``plot_network_kwargs`` are passed directly to ``networkx.draw_networkx``. Supported ``layout`` values are ``spring``, ``kamada_kawai``, ``circular``, ``random``, and ``shell``. Common kwargs include ``with_labels`` (``True``/``False``), ``labels`` (e.g. ``{agent: agent.id}``), ``node_color``, ``node_size``, and ``font_size``. Use ``plot_network=True`` to draw automatically at creation, or call ``net.plot(...)`` later with the same kwargs. +``plot_network_kwargs`` are passed directly to :func:`networkx.draw_networkx`. Supported ``layout`` values are :func:`networkx.spring_layout`, :func:`networkx.kamada_kawai_layout`, :func:`networkx.circular_layout`, :func:`networkx.random_layout`, and :func:`networkx.shell_layout`. Common kwargs include ``with_labels`` (``True``/``False``), ``labels`` (e.g. 
``{agent: agent.id}``), ``node_color``, ``node_size``, and ``font_size``. Use ``plot_network=True`` to draw automatically at creation, or call ``net.plot(...)`` later with the same kwargs. Create problems using existing resources From 778dea6e934c27776423d799576253b019a718c8 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Thu, 25 Dec 2025 13:27:00 +0100 Subject: [PATCH 14/16] ref(networks): move plotting to util and update docs (#233) Remove Network.plot and plot flags from BenchmarkProblem, add network_utils.plot_network helper, and shift plotting guidance to a standalone doc section with intersphinx links. --- decent_bench/__init__.py | 2 + decent_bench/benchmark_problem.py | 11 --- decent_bench/network_utils.py | 68 ++++++++++++++++++ decent_bench/networks.py | 69 ++----------------- .../source/api/decent_bench.network_utils.rst | 7 ++ docs/source/api/decent_bench.rst | 1 + docs/source/user.rst | 28 +++++--- 7 files changed, 102 insertions(+), 84 deletions(-) create mode 100644 decent_bench/network_utils.py create mode 100644 docs/source/api/decent_bench.network_utils.rst diff --git a/decent_bench/__init__.py b/decent_bench/__init__.py index f589e70..3ed3bc6 100644 --- a/decent_bench/__init__.py +++ b/decent_bench/__init__.py @@ -7,6 +7,7 @@ datasets, distributed_algorithms, metrics, + network_utils, networks, schemes, ) @@ -20,6 +21,7 @@ "datasets", "distributed_algorithms", "metrics", + "network_utils", "networks", "schemes", ] diff --git a/decent_bench/benchmark_problem.py b/decent_bench/benchmark_problem.py index d372c75..ad7377a 100644 --- a/decent_bench/benchmark_problem.py +++ b/decent_bench/benchmark_problem.py @@ -45,8 +45,6 @@ class BenchmarkProblem: message_compression: message compression setting message_noise: message noise setting message_drop: message drops setting - plot_network: plot the network when it is created (optional) - plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` """ @@ -57,8 +55,6 @@ class 
BenchmarkProblem: message_compression: CompressionScheme message_noise: NoiseScheme message_drop: DropScheme - plot_network: bool = False - plot_network_kwargs: dict[str, Any] | None = None def create_regression_problem( @@ -70,8 +66,6 @@ def create_regression_problem( compression: bool = False, noise: bool = False, drops: bool = False, - plot_network: bool = False, - plot_network_kwargs: dict[str, Any] | None = None, ) -> BenchmarkProblem: """ Create out-of-the-box regression problems. @@ -84,9 +78,6 @@ def create_regression_problem( compression: if true, messages are rounded to 4 significant digits noise: if true, messages are distorted by Gaussian noise drops: if true, messages have a 50% probability of being dropped - plot_network: if true, plot the network when it is created - plot_network_kwargs: kwargs forwarded to :meth:`decent_bench.networks.Network.plot` when - ``plot_network`` is true """ network_structure = nx.random_regular_graph(n_neighbors_per_agent, n_agents, seed=0) @@ -114,6 +105,4 @@ def create_regression_problem( message_compression=message_compression, message_noise=message_noise, message_drop=message_drop, - plot_network=plot_network, - plot_network_kwargs=plot_network_kwargs, ) diff --git a/decent_bench/network_utils.py b/decent_bench/network_utils.py new file mode 100644 index 0000000..8476116 --- /dev/null +++ b/decent_bench/network_utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Literal + +import matplotlib.axes +import networkx +import networkx as nx + +_LAYOUT_FUNCS: dict[Literal["spring", "kamada_kawai", "circular", "random", "shell"], Any] = { + "spring": nx.drawing.layout.spring_layout, + "kamada_kawai": nx.drawing.layout.kamada_kawai_layout, + "circular": nx.drawing.layout.circular_layout, + "random": nx.drawing.layout.random_layout, + "shell": nx.drawing.layout.shell_layout, +} + + +def plot_network( + graph: networkx.Graph[Any], + *, + ax: 
matplotlib.axes.Axes | None = None, + layout: Literal["spring", "kamada_kawai", "circular", "random", "shell"] = "spring", + **draw_kwargs: Mapping[str, object], +) -> matplotlib.axes.Axes: + """ + Plot a NetworkX graph using the built-in NetworkX drawing utilities. + + Args: + graph: NetworkX graph to plot. + ax: optional :class:`matplotlib.axes.Axes` to draw on. If ``None`` a new figure is created. + layout: layout algorithm to position nodes (e.g. :func:`networkx.drawing.layout.spring_layout`, + :func:`networkx.drawing.layout.kamada_kawai_layout`, + :func:`networkx.drawing.layout.circular_layout`, + :func:`networkx.drawing.layout.random_layout`, + :func:`networkx.drawing.layout.shell_layout`). + draw_kwargs: forwarded to :func:`networkx.drawing.nx_pylab.draw_networkx`. + + Returns: + The matplotlib :class:`matplotlib.axes.Axes` containing the plot. + + Raises: + RuntimeError: if matplotlib is not available. + ValueError: if an unsupported layout is requested. + + """ + try: + import matplotlib.pyplot as plt # noqa: PLC0415 + except Exception as exc: # pragma: no cover - runtime dependency guard + raise RuntimeError("matplotlib is required for plotting the network") from exc + + layout_func = _LAYOUT_FUNCS.get(layout) + if layout_func is None: + supported = ", ".join(sorted(_LAYOUT_FUNCS)) + raise ValueError(f"Unsupported layout '{layout}'. 
Supported layouts: {supported}") + + pos = layout_func(graph) + if ax is None: + _, ax = plt.subplots() + + draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs) + nx.drawing.nx_pylab.draw_networkx( + graph, + pos=pos, + ax=ax, + **draw_kwargs_dict, + ) + return ax diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 1120bd0..66a2da3 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -3,11 +3,10 @@ from abc import ABC from collections.abc import Collection, Mapping, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast import networkx as nx import numpy as np -from matplotlib.axes import Axes import decent_bench.utils.interoperability as iop from decent_bench.agents import Agent @@ -20,16 +19,6 @@ else: AgentGraph = nx.Graph -Layout = Literal["spring", "kamada_kawai", "circular", "random", "shell"] - -_LAYOUT_FUNCS: dict[Layout, Any] = { - "spring": nx.spring_layout, - "kamada_kawai": nx.kamada_kawai_layout, - "circular": nx.circular_layout, - "random": nx.random_layout, - "shell": nx.shell_layout, -} - class Network(ABC): # noqa: B024 """Base network object defining communication logic shared by all network types.""" @@ -83,48 +72,6 @@ def connected_agents(self, agent: Agent) -> list[Agent]: """Agents directly connected to ``agent`` in the underlying graph.""" return list(self.graph.neighbors(agent)) - def plot(self, ax: Axes | None = None, layout: Layout = "spring", **draw_kwargs: Mapping[str, object]) -> Axes: - """ - Plot the network using NetworkX drawing utilities. - - Args: - ax: optional matplotlib Axes to draw on. If ``None`` a new figure is created. - layout: layout algorithm to position nodes (e.g. :func:`networkx.spring_layout`, - :func:`networkx.kamada_kawai_layout`, :func:`networkx.circular_layout`, - :func:`networkx.random_layout`, :func:`networkx.shell_layout`). - draw_kwargs: forwarded to :func:`networkx.draw_networkx`. 
- - Returns: - The matplotlib Axes containing the plot. - - Raises: - RuntimeError: if matplotlib is not available. - ValueError: if an unsupported layout is requested. - - """ - try: - import matplotlib.pyplot as plt # noqa: PLC0415 - except Exception as exc: # pragma: no cover - runtime dependency guard - raise RuntimeError("matplotlib is required for plotting the network") from exc - - layout_func = _LAYOUT_FUNCS.get(layout) - if layout_func is None: - supported = ", ".join(sorted(_LAYOUT_FUNCS)) - raise ValueError(f"Unsupported layout '{layout}'. Supported layouts: {supported}") - - pos = layout_func(self.graph) - if ax is None: - _, ax = plt.subplots() - - draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs) - nx.draw_networkx( - self.graph, - pos=pos, - ax=ax, - **draw_kwargs_dict, - ) - return ax - def _send_one(self, sender: Agent, receiver: Agent, msg: Array) -> None: """ Send message to an agent. @@ -302,7 +249,7 @@ def adjacency(self) -> Array: """ agents = self.agents() adjacency_matrix = nx.to_numpy_array( - cast("nx.Graph[Any]", self.graph), + self.graph, nodelist=cast("Collection[Any]", agents), dtype=float, ) # type: ignore[call-overload] @@ -551,16 +498,12 @@ def create_distributed_network(problem: BenchmarkProblem) -> P2PNetwork: agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) - network = P2PNetwork( + return P2PNetwork( graph=graph, message_noise=problem.message_noise, message_compression=problem.message_compression, message_drop=problem.message_drop, ) - if getattr(problem, "plot_network", False): - plot_kwargs = getattr(problem, "plot_network_kwargs", None) or {} - network.plot(**plot_kwargs) - return network def create_federated_network(problem: BenchmarkProblem) -> FedNetwork: @@ -591,13 +534,9 @@ def create_federated_network(problem: 
BenchmarkProblem) -> FedNetwork: agents = [Agent(i, problem.costs[i], problem.agent_activations[i]) for i in range(n_agents)] agent_node_map = {node: agents[i] for i, node in enumerate(problem.network_structure.nodes())} graph = nx.relabel_nodes(problem.network_structure, agent_node_map) - network = FedNetwork( + return FedNetwork( graph=graph, message_noise=problem.message_noise, message_compression=problem.message_compression, message_drop=problem.message_drop, ) - if getattr(problem, "plot_network", False): - plot_kwargs = getattr(problem, "plot_network_kwargs", None) or {} - network.plot(**plot_kwargs) - return network diff --git a/docs/source/api/decent_bench.network_utils.rst b/docs/source/api/decent_bench.network_utils.rst new file mode 100644 index 0000000..6634ac2 --- /dev/null +++ b/docs/source/api/decent_bench.network_utils.rst @@ -0,0 +1,7 @@ +decent\_bench.network\_utils +============================ + +.. automodule:: decent_bench.network_utils + :members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/docs/source/api/decent_bench.rst b/docs/source/api/decent_bench.rst index cc7e857..9c490f2 100644 --- a/docs/source/api/decent_bench.rst +++ b/docs/source/api/decent_bench.rst @@ -18,6 +18,7 @@ decent\_bench decent_bench.costs decent_bench.datasets decent_bench.distributed_algorithms + decent_bench.network_utils decent_bench.networks decent_bench.schemes diff --git a/docs/source/user.rst b/docs/source/user.rst index 3240b60..df598ca 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -94,9 +94,6 @@ Configure communication constraints and other settings for out-of-the-box regres compression=True, noise=True, drops=True, - # Optional: plot the network when it is created - plot_network=False, - plot_network_kwargs=None, ) if __name__ == "__main__": @@ -133,8 +130,6 @@ Change the settings of an already created benchmark problem, for example, the ne compression=True, noise=True, drops=True, - plot_network=True, - 
plot_network_kwargs={"layout": "circular", "with_labels": True}, ) problem.network_structure = nx.random_regular_graph(n_neighbors_per_agent, n_agents) @@ -149,10 +144,27 @@ Change the settings of an already created benchmark problem, for example, the ne benchmark_problem=problem, ) +Network utilities +----------------- +Plot a network explicitly when you need it: -Notes on plotting -~~~~~~~~~~~~~~~~~ -``plot_network_kwargs`` are passed directly to :func:`networkx.draw_networkx`. Supported ``layout`` values are :func:`networkx.spring_layout`, :func:`networkx.kamada_kawai_layout`, :func:`networkx.circular_layout`, :func:`networkx.random_layout`, and :func:`networkx.shell_layout`. Common kwargs include ``with_labels`` (``True``/``False``), ``labels`` (e.g. ``{agent: agent.id}``), ``node_color``, ``node_size``, and ``font_size``. Use ``plot_network=True`` to draw automatically at creation, or call ``net.plot(...)`` later with the same kwargs. +.. code-block:: python + + import networkx as nx + from decent_bench import benchmark, benchmark_problem, network_utils + from decent_bench.costs import LinearRegressionCost + from decent_bench.distributed_algorithms import ADMM, DGD, ED + + problem = benchmark_problem.create_regression_problem(LinearRegressionCost, n_agents=25, n_neighbors_per_agent=3) + + # Plot using decent-bench helper (wraps :func:`networkx.drawing.nx_pylab.draw_networkx`) + network_utils.plot_network(problem.network_structure, layout="circular", with_labels=True) + + # Or call NetworkX directly on the graph + pos = nx.drawing.layout.spring_layout(problem.network_structure) + nx.drawing.nx_pylab.draw_networkx(problem.network_structure, pos=pos, with_labels=True) + +For more options, see the `NetworkX drawing guide <https://networkx.org/documentation/stable/reference/drawing.html>`_.
Create problems using existing resources From da5386190aad1131e898a34994d1c69ab1dba123 Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Fri, 26 Dec 2025 22:08:24 +0100 Subject: [PATCH 15/16] ref(networks): Simplify FedNetwork API and relocate utils (#233) Move network_utils into utils namespace. Simplify FedNetwork messaging helpers. Update API docs and user guide imports for the new utils path. --- decent_bench/__init__.py | 2 - decent_bench/networks.py | 81 +------------------ decent_bench/{ => utils}/network_utils.py | 0 .../source/api/decent_bench.network_utils.rst | 7 -- docs/source/api/decent_bench.rst | 2 +- .../api/decent_bench.utils.network_utils.rst | 7 ++ docs/source/api/decent_bench.utils.rst | 1 + docs/source/user.rst | 3 +- 8 files changed, 14 insertions(+), 89 deletions(-) rename decent_bench/{ => utils}/network_utils.py (100%) delete mode 100644 docs/source/api/decent_bench.network_utils.rst create mode 100644 docs/source/api/decent_bench.utils.network_utils.rst diff --git a/decent_bench/__init__.py b/decent_bench/__init__.py index 3ed3bc6..f589e70 100644 --- a/decent_bench/__init__.py +++ b/decent_bench/__init__.py @@ -7,7 +7,6 @@ datasets, distributed_algorithms, metrics, - network_utils, networks, schemes, ) @@ -21,7 +20,6 @@ "datasets", "distributed_algorithms", "metrics", - "network_utils", "networks", "schemes", ] diff --git a/decent_bench/networks.py b/decent_bench/networks.py index abe51d3..4e229ad 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from collections.abc import Collection, Mapping, Sequence +from collections.abc import Collection, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, cast @@ -392,86 +392,11 @@ def receive(self, receiver: Agent, sender: Agent | Sequence[Agent] | None = None raise ValueError("All senders must be clients") super().receive(receiver=receiver, sender=sender) - def 
send_to_client(self, client: Agent, msg: Array) -> None: - """ - Send a message from the server to a specific client. - - Raises: - ValueError: if the receiver is not a client. - - """ - if client not in self.clients: - raise ValueError("Receiver must be a client") - self.send(sender=self.server, receiver=client, msg=msg) - - def send_to_all_clients(self, msg: Array) -> None: + def broadcast(self, msg: Array) -> None: """Send the same message from the server to every client (synchronous FL push).""" self.send(sender=self.server, receiver=None, msg=msg) - def send_from_client(self, client: Agent, msg: Array) -> None: - """ - Send a message from a client to the server. - - Raises: - ValueError: if the sender is not a client. - - """ - if client not in self.clients: - raise ValueError("Sender must be a client") - self.send(sender=client, receiver=self.server, msg=msg) - - def send_from_all_clients(self, msgs: Mapping[Agent, Array]) -> None: - """ - Send messages from each client to the server (synchronous FL push). - - Args: - msgs: mapping from client Agent to the message that client should send. Must include all clients. - - Raises: - ValueError: if any sender is not a client or if any client is missing. - - """ - clients = set(self.clients) - senders = set(msgs) - invalid = senders - clients - if invalid: - raise ValueError("All senders must be clients") - missing = clients - senders - if missing: - raise ValueError("Messages must be provided for all clients") - for client, msg in msgs.items(): - self.send_from_client(client, msg) - - def receive_at_client(self, client: Agent) -> None: - """ - Receive a message at a client from the server. - - Raises: - ValueError: if the receiver is not a client. 
- - """ - if client not in self.clients: - raise ValueError("Receiver must be a client") - self.receive(receiver=client, sender=None) - - def receive_at_all_clients(self) -> None: - """Receive messages at every client from the server (synchronous FL pull).""" - for client in self.clients: - self.receive_at_client(client) - - def receive_from_client(self, client: Agent) -> None: - """ - Receive a message at the server from a specific client. - - Raises: - ValueError: if the sender is not a client. - - """ - if client not in self.clients: - raise ValueError("Sender must be a client") - self.receive(receiver=self.server, sender=client) - - def receive_from_all_clients(self) -> None: + def receive_all(self) -> None: """Receive messages at the server from every client (synchronous FL pull).""" self.receive(receiver=self.server, sender=None) diff --git a/decent_bench/network_utils.py b/decent_bench/utils/network_utils.py similarity index 100% rename from decent_bench/network_utils.py rename to decent_bench/utils/network_utils.py diff --git a/docs/source/api/decent_bench.network_utils.rst b/docs/source/api/decent_bench.network_utils.rst deleted file mode 100644 index 6634ac2..0000000 --- a/docs/source/api/decent_bench.network_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -decent\_bench.network\_utils -============================ - -.. 
automodule:: decent_bench.network_utils - :members: - :show-inheritance: - :undoc-members: \ No newline at end of file diff --git a/docs/source/api/decent_bench.rst b/docs/source/api/decent_bench.rst index 9c490f2..87d938d 100644 --- a/docs/source/api/decent_bench.rst +++ b/docs/source/api/decent_bench.rst @@ -18,7 +18,7 @@ decent\_bench decent_bench.costs decent_bench.datasets decent_bench.distributed_algorithms - decent_bench.network_utils + decent_bench.utils.network_utils decent_bench.networks decent_bench.schemes diff --git a/docs/source/api/decent_bench.utils.network_utils.rst b/docs/source/api/decent_bench.utils.network_utils.rst new file mode 100644 index 0000000..dae7cfa --- /dev/null +++ b/docs/source/api/decent_bench.utils.network_utils.rst @@ -0,0 +1,7 @@ +decent_bench.utils.network_utils +================================ + +.. automodule:: decent_bench.utils.network_utils + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/source/api/decent_bench.utils.rst b/docs/source/api/decent_bench.utils.rst index 9c5126e..b12e086 100644 --- a/docs/source/api/decent_bench.utils.rst +++ b/docs/source/api/decent_bench.utils.rst @@ -9,6 +9,7 @@ decent\_bench.utils decent_bench.utils.array decent_bench.utils.interoperability decent_bench.utils.logger + decent_bench.utils.network_utils decent_bench.utils.progress_bar decent_bench.utils.types diff --git a/docs/source/user.rst b/docs/source/user.rst index 7752728..b9f2d6f 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -157,7 +157,8 @@ Plot a network explicitly when you need it: .. 
code-block:: python import networkx as nx - from decent_bench import benchmark, benchmark_problem, network_utils + from decent_bench import benchmark, benchmark_problem + from decent_bench.utils import network_utils from decent_bench.costs import LinearRegressionCost from decent_bench.distributed_algorithms import ADMM, DGD, ED From 8875b5175bdc78071ae7435c5fd2532acfb3a9db Mon Sep 17 00:00:00 2001 From: Adriana Rodriguez Date: Mon, 29 Dec 2025 17:55:48 +0100 Subject: [PATCH 16/16] docs(user-guide): tidy network utilities section (#233) Remove unused imports in the network utilities example and reorder the section. --- docs/source/user.rst | 47 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/docs/source/user.rst b/docs/source/user.rst index b9f2d6f..4ccdb24 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -150,29 +150,6 @@ Change the settings of an already created benchmark problem, for example, the ne benchmark_problem=problem, ) -Network utilities ------------------ -Plot a network explicitly when you need it: - -.. code-block:: python - - import networkx as nx - from decent_bench import benchmark, benchmark_problem - from decent_bench.utils import network_utils - from decent_bench.costs import LinearRegressionCost - from decent_bench.distributed_algorithms import ADMM, DGD, ED - - problem = benchmark_problem.create_regression_problem(LinearRegressionCost, n_agents=25, n_neighbors_per_agent=3) - - # Plot using decent-bench helper (wraps :func:`networkx.drawing.nx_pylab.draw_networkx`) - network_utils.plot_network(problem.network_structure, layout="circular", with_labels=True) - - # Or call NetworkX directly on the graph - pos = nx.drawing.layout.spring_layout(problem.network_structure) - nx.drawing.nx_pylab.draw_networkx(problem.network_structure, pos=pos, with_labels=True) - -For more options, see the `NetworkX drawing guide `_. 
- Create problems using existing resources ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -284,6 +261,30 @@ corresponding abstracts. benchmark_problem=problem, ) + +Network utilities +----------------- +Plot a network explicitly when you need it: + +.. code-block:: python + + import networkx as nx + from decent_bench import benchmark_problem + from decent_bench.utils import network_utils + from decent_bench.costs import LinearRegressionCost + + problem = benchmark_problem.create_regression_problem(LinearRegressionCost, n_agents=25, n_neighbors_per_agent=3) + + # Plot using decent-bench helper (wraps :func:`networkx.drawing.nx_pylab.draw_networkx`) + network_utils.plot_network(problem.network_structure, layout="circular", with_labels=True) + + # Or call NetworkX directly on the graph + pos = nx.drawing.layout.spring_layout(problem.network_structure) + nx.drawing.nx_pylab.draw_networkx(problem.network_structure, pos=pos, with_labels=True) + +For more options, see the `NetworkX drawing guide <https://networkx.org/documentation/stable/reference/drawing.html>`_. + + Interoperability requirement ---------------------------- Decent-Bench is designed to interoperate with multiple array/tensor frameworks (NumPy, PyTorch, JAX, etc.). To keep