diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 90fc7ca..4e229ad 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from abc import ABC -from collections.abc import Mapping, Sequence +from collections.abc import Collection, Sequence from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast import networkx as nx import numpy as np @@ -47,6 +49,16 @@ def agents(self) -> list[Agent]: """Get all agents in the network.""" return list(self.graph) + @property + def degrees(self) -> dict[Agent, int]: + """Degree of each agent in the network.""" + return dict(self.graph.degree()) + + @property + def edges(self) -> list[tuple[Agent, Agent]]: + """Edges of the network as (agent, agent) tuples.""" + return list(self.graph.edges()) + def active_agents(self, iteration: int) -> list[Agent]: """ Get all active agents. @@ -215,11 +227,12 @@ def weights(self) -> Array: n = len(agents) W = np.zeros((n, n)) # noqa: N806 + degrees = self.degrees for i in agents: neighbors = self.neighbors(i) - d_i = len(neighbors) + d_i = degrees[i] for j in neighbors: - d_j = len(self.neighbors(j)) + d_j = degrees[j] W[i, j] = 1 / (1 + max(d_i, d_j)) for i in agents: W[i, i] = 1 - sum(W[i]) @@ -235,13 +248,16 @@ def adjacency(self) -> Array: Use ``adjacency[i, j]`` or ``adjacency[i.id, j.id]`` to get the adjacency between agent i and j. """ agents = self.agents() - n = len(agents) - A = np.zeros((n, n)) # noqa: N806 - for i in agents: - for j in self.neighbors(i): - A[i, j] = 1 - - return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device) + adjacency_matrix = nx.to_numpy_array( + self.graph, + nodelist=cast("Collection[Any]", agents), + dtype=float, + ) # type: ignore[call-overload] + return iop.to_array( + adjacency_matrix, + agents[0].cost.framework, + agents[0].cost.device, + ) def neighbors(self, agent: Agent) -> list[Agent]: """Alias for :meth:`~decent_bench.networks.Network.connected_agents`.""" @@ -376,86 +392,11 @@ def receive(self, receiver: Agent, sender: Agent | Sequence[Agent] | None = None raise ValueError("All senders must be clients") super().receive(receiver=receiver, sender=sender) - def send_to_client(self, client: Agent, msg: Array) -> None: - """ - Send a message from the server to a specific client. - - Raises: - ValueError: if the receiver is not a client. - - """ - if client not in self.clients: - raise ValueError("Receiver must be a client") - self.send(sender=self.server, receiver=client, msg=msg) - - def send_to_all_clients(self, msg: Array) -> None: + def broadcast(self, msg: Array) -> None: """Send the same message from the server to every client (synchronous FL push).""" self.send(sender=self.server, receiver=None, msg=msg) - def send_from_client(self, client: Agent, msg: Array) -> None: - """ - Send a message from a client to the server. - - Raises: - ValueError: if the sender is not a client. - - """ - if client not in self.clients: - raise ValueError("Sender must be a client") - self.send(sender=client, receiver=self.server, msg=msg) - - def send_from_all_clients(self, msgs: Mapping[Agent, Array]) -> None: - """ - Send messages from each client to the server (synchronous FL push). - - Args: - msgs: mapping from client Agent to the message that client should send. Must include all clients. - - Raises: - ValueError: if any sender is not a client or if any client is missing. - - """ - clients = set(self.clients) - senders = set(msgs) - invalid = senders - clients - if invalid: - raise ValueError("All senders must be clients") - missing = clients - senders - if missing: - raise ValueError("Messages must be provided for all clients") - for client, msg in msgs.items(): - self.send_from_client(client, msg) - - def receive_at_client(self, client: Agent) -> None: - """ - Receive a message at a client from the server. - - Raises: - ValueError: if the receiver is not a client. - - """ - if client not in self.clients: - raise ValueError("Receiver must be a client") - self.receive(receiver=client, sender=None) - - def receive_at_all_clients(self) -> None: - """Receive messages at every client from the server (synchronous FL pull).""" - for client in self.clients: - self.receive_at_client(client) - - def receive_from_client(self, client: Agent) -> None: - """ - Receive a message at the server from a specific client. - - Raises: - ValueError: if the sender is not a client. - - """ - if client not in self.clients: - raise ValueError("Sender must be a client") - self.receive(receiver=self.server, sender=client) - - def receive_from_all_clients(self) -> None: + def receive_all(self) -> None: """Receive messages at the server from every client (synchronous FL pull).""" self.receive(receiver=self.server, sender=None) diff --git a/decent_bench/utils/network_utils.py b/decent_bench/utils/network_utils.py new file mode 100644 index 0000000..8476116 --- /dev/null +++ b/decent_bench/utils/network_utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Literal + +import matplotlib.axes +import networkx +import networkx as nx + +_LAYOUT_FUNCS: dict[Literal["spring", "kamada_kawai", "circular", "random", "shell"], Any] = { + "spring": nx.drawing.layout.spring_layout, + "kamada_kawai": nx.drawing.layout.kamada_kawai_layout, + "circular": nx.drawing.layout.circular_layout, + "random": nx.drawing.layout.random_layout, + "shell": nx.drawing.layout.shell_layout, +} + + +def plot_network( + graph: networkx.Graph[Any], + *, + ax: matplotlib.axes.Axes | None = None, + layout: Literal["spring", "kamada_kawai", "circular", "random", "shell"] = "spring", + **draw_kwargs: Mapping[str, object], +) -> matplotlib.axes.Axes: + """ + Plot a NetworkX graph using the built-in NetworkX drawing utilities. + + Args: + graph: NetworkX graph to plot. + ax: optional :class:`matplotlib.axes.Axes` to draw on. If ``None`` a new figure is created. + layout: layout algorithm to position nodes (e.g. :func:`networkx.drawing.layout.spring_layout`, + :func:`networkx.drawing.layout.kamada_kawai_layout`, + :func:`networkx.drawing.layout.circular_layout`, + :func:`networkx.drawing.layout.random_layout`, + :func:`networkx.drawing.layout.shell_layout`). + draw_kwargs: forwarded to :func:`networkx.drawing.nx_pylab.draw_networkx`. + + Returns: + The matplotlib :class:`matplotlib.axes.Axes` containing the plot. + + Raises: + RuntimeError: if matplotlib is not available. + ValueError: if an unsupported layout is requested. + + """ + try: + import matplotlib.pyplot as plt # noqa: PLC0415 + except Exception as exc: # pragma: no cover - runtime dependency guard + raise RuntimeError("matplotlib is required for plotting the network") from exc + + layout_func = _LAYOUT_FUNCS.get(layout) + if layout_func is None: + supported = ", ".join(sorted(_LAYOUT_FUNCS)) + raise ValueError(f"Unsupported layout '{layout}'. Supported layouts: {supported}") + + pos = layout_func(graph) + if ax is None: + _, ax = plt.subplots() + + draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs) + nx.drawing.nx_pylab.draw_networkx( + graph, + pos=pos, + ax=ax, + **draw_kwargs_dict, + ) + return ax diff --git a/docs/source/api/decent_bench.rst b/docs/source/api/decent_bench.rst index cc7e857..87d938d 100644 --- a/docs/source/api/decent_bench.rst +++ b/docs/source/api/decent_bench.rst @@ -18,6 +18,7 @@ decent\_bench decent_bench.costs decent_bench.datasets decent_bench.distributed_algorithms + decent_bench.utils.network_utils decent_bench.networks decent_bench.schemes diff --git a/docs/source/api/decent_bench.utils.network_utils.rst b/docs/source/api/decent_bench.utils.network_utils.rst new file mode 100644 index 0000000..dae7cfa --- /dev/null +++ b/docs/source/api/decent_bench.utils.network_utils.rst @@ -0,0 +1,7 @@ +decent_bench.utils.network_utils +================================ + +.. automodule:: decent_bench.utils.network_utils + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/source/api/decent_bench.utils.rst b/docs/source/api/decent_bench.utils.rst index 9c5126e..b12e086 100644 --- a/docs/source/api/decent_bench.utils.rst +++ b/docs/source/api/decent_bench.utils.rst @@ -9,6 +9,7 @@ decent\_bench.utils decent_bench.utils.array decent_bench.utils.interoperability decent_bench.utils.logger + decent_bench.utils.network_utils decent_bench.utils.progress_bar decent_bench.utils.types diff --git a/docs/source/conf.py b/docs/source/conf.py index 13368d3..805bc62 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,6 +49,7 @@ "numpy": ("https://numpy.org/doc/stable/", None), "python": ("https://docs.python.org/3", None), "torch": ("https://pytorch.org/docs/stable/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), "tensorflow": ( "https://www.tensorflow.org/api_docs/python", "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", diff --git a/docs/source/user.rst b/docs/source/user.rst index 67a6dde..4ccdb24 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -261,6 +261,30 @@ corresponding abstracts. benchmark_problem=problem, ) + +Network utilities +----------------- +Plot a network explicitly when you need it: + +.. code-block:: python + + import networkx as nx + from decent_bench import benchmark_problem + from decent_bench.utils import network_utils + from decent_bench.costs import LinearRegressionCost + + problem = benchmark_problem.create_regression_problem(LinearRegressionCost, n_agents=25, n_neighbors_per_agent=3) + + # Plot using decent-bench helper (wraps :func:`networkx.drawing.nx_pylab.draw_networkx`) + network_utils.plot_network(problem.network_structure, layout="circular", with_labels=True) + + # Or call NetworkX directly on the graph + pos = nx.drawing.layout.spring_layout(problem.network_structure) + nx.drawing.nx_pylab.draw_networkx(problem.network_structure, pos=pos, with_labels=True) + +For more options, see the `NetworkX drawing guide `_. + + Interoperability requirement ---------------------------- Decent-Bench is designed to interoperate with multiple array/tensor frameworks (NumPy, PyTorch, JAX, etc.). To keep