Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1380951
feat(networks): add network base and federated support
Dec 9, 2025
9a669db
wip(fl-setup): save network changes after merge (Array)
Dec 10, 2025
96a89e6
fix(networks): resolve lint and doc warning
Dec 12, 2025
8c22e64
merge(fl-setup): sync with main and verify suite
Dec 12, 2025
e6fce23
enh(networks): address review feedback (#229)
Dec 13, 2025
9b968d9
merge(fl-setup): sync with main and verify suite
Dec 17, 2025
6973a75
ref(networks): align APIs with PR feedback (#229)
Dec 18, 2025
38075db
enh(networks): address review feedback (#229)
Dec 19, 2025
0b2acca
ref(networks): Enforce connectivity checks in send/receive (#229)
Dec 19, 2025
2e59602
enh(networks): Alias neighbors to base connected_agents (#229)
Dec 19, 2025
3ff5876
feat(networks): expose networkx helpers and plotting
Dec 19, 2025
8efe189
enh(networks): Detail connection validation errors (#229)
Dec 20, 2025
e9880e2
fix(networks): align plotting typing and docs config
Dec 20, 2025
cff0081
merge(network): merge remote-tracking branch 'origin/feat/fl-setup' i…
Dec 20, 2025
d342f32
merge(network): Merge latest 'main' into feat/networkx-extras and run…
Dec 20, 2025
d581924
docs(user): Document network plot kwargs and layouts
Dec 20, 2025
e5dcad2
wip(networkx-extras): Address PR review feedback
Dec 24, 2025
778dea6
ref(networks): move plotting to util and update docs (#233)
Dec 25, 2025
2547010
Merge remote-tracking branch 'origin/main' into feat/networkx-extras
Dec 25, 2025
da53861
ref(networks): Simplify FedNetwork API and relocate utils (#233)
Dec 26, 2025
8875b51
docs(user-guide): tidy network utilities section (#233)
Dec 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 29 additions & 88 deletions decent_bench/networks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Mapping, Sequence
from collections.abc import Collection, Sequence
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -47,6 +49,16 @@ def agents(self) -> list[Agent]:
"""Get all agents in the network."""
return list(self.graph)

@property
def degrees(self) -> dict[Agent, int]:
"""Degree of each agent in the network."""
return dict(self.graph.degree())

@property
def edges(self) -> list[tuple[Agent, Agent]]:
"""Edges of the network as (agent, agent) tuples."""
return list(self.graph.edges())

def active_agents(self, iteration: int) -> list[Agent]:
"""
Get all active agents.
Expand Down Expand Up @@ -215,11 +227,12 @@ def weights(self) -> Array:

n = len(agents)
W = np.zeros((n, n)) # noqa: N806
degrees = self.degrees
for i in agents:
neighbors = self.neighbors(i)
d_i = len(neighbors)
d_i = degrees[i]
for j in neighbors:
d_j = len(self.neighbors(j))
d_j = degrees[j]
W[i, j] = 1 / (1 + max(d_i, d_j))
for i in agents:
W[i, i] = 1 - sum(W[i])
Expand All @@ -235,13 +248,16 @@ def adjacency(self) -> Array:
Use ``adjacency[i, j]`` or ``adjacency[i.id, j.id]`` to get the adjacency between agent i and j.
"""
agents = self.agents()
n = len(agents)
A = np.zeros((n, n)) # noqa: N806
for i in agents:
for j in self.neighbors(i):
A[i, j] = 1

return iop.to_array(A, agents[0].cost.framework, agents[0].cost.device)
adjacency_matrix = nx.to_numpy_array(
self.graph,
nodelist=cast("Collection[Any]", agents),
dtype=float,
) # type: ignore[call-overload]
return iop.to_array(
adjacency_matrix,
agents[0].cost.framework,
agents[0].cost.device,
)

def neighbors(self, agent: Agent) -> list[Agent]:
"""Alias for :meth:`~decent_bench.networks.Network.connected_agents`."""
Expand Down Expand Up @@ -376,86 +392,11 @@ def receive(self, receiver: Agent, sender: Agent | Sequence[Agent] | None = None
raise ValueError("All senders must be clients")
super().receive(receiver=receiver, sender=sender)

def send_to_client(self, client: Agent, msg: Array) -> None:
"""
Send a message from the server to a specific client.

Raises:
ValueError: if the receiver is not a client.

"""
if client not in self.clients:
raise ValueError("Receiver must be a client")
self.send(sender=self.server, receiver=client, msg=msg)

def send_to_all_clients(self, msg: Array) -> None:
def broadcast(self, msg: Array) -> None:
"""Send the same message from the server to every client (synchronous FL push)."""
self.send(sender=self.server, receiver=None, msg=msg)

def send_from_client(self, client: Agent, msg: Array) -> None:
"""
Send a message from a client to the server.

Raises:
ValueError: if the sender is not a client.

"""
if client not in self.clients:
raise ValueError("Sender must be a client")
self.send(sender=client, receiver=self.server, msg=msg)

def send_from_all_clients(self, msgs: Mapping[Agent, Array]) -> None:
"""
Send messages from each client to the server (synchronous FL push).

Args:
msgs: mapping from client Agent to the message that client should send. Must include all clients.

Raises:
ValueError: if any sender is not a client or if any client is missing.

"""
clients = set(self.clients)
senders = set(msgs)
invalid = senders - clients
if invalid:
raise ValueError("All senders must be clients")
missing = clients - senders
if missing:
raise ValueError("Messages must be provided for all clients")
for client, msg in msgs.items():
self.send_from_client(client, msg)

def receive_at_client(self, client: Agent) -> None:
"""
Receive a message at a client from the server.

Raises:
ValueError: if the receiver is not a client.

"""
if client not in self.clients:
raise ValueError("Receiver must be a client")
self.receive(receiver=client, sender=None)

def receive_at_all_clients(self) -> None:
"""Receive messages at every client from the server (synchronous FL pull)."""
for client in self.clients:
self.receive_at_client(client)

def receive_from_client(self, client: Agent) -> None:
"""
Receive a message at the server from a specific client.

Raises:
ValueError: if the sender is not a client.

"""
if client not in self.clients:
raise ValueError("Sender must be a client")
self.receive(receiver=self.server, sender=client)

def receive_from_all_clients(self) -> None:
def receive_all(self) -> None:
"""Receive messages at the server from every client (synchronous FL pull)."""
self.receive(receiver=self.server, sender=None)

Expand Down
68 changes: 68 additions & 0 deletions decent_bench/utils/network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Literal

import matplotlib.axes
import networkx
import networkx as nx

_LAYOUT_FUNCS: dict[Literal["spring", "kamada_kawai", "circular", "random", "shell"], Any] = {
"spring": nx.drawing.layout.spring_layout,
"kamada_kawai": nx.drawing.layout.kamada_kawai_layout,
"circular": nx.drawing.layout.circular_layout,
"random": nx.drawing.layout.random_layout,
"shell": nx.drawing.layout.shell_layout,
}


def plot_network(
graph: networkx.Graph[Any],
*,
ax: matplotlib.axes.Axes | None = None,
layout: Literal["spring", "kamada_kawai", "circular", "random", "shell"] = "spring",
**draw_kwargs: Mapping[str, object],
) -> matplotlib.axes.Axes:
"""
Plot a NetworkX graph using the built-in NetworkX drawing utilities.

Args:
graph: NetworkX graph to plot.
ax: optional :class:`matplotlib.axes.Axes` to draw on. If ``None`` a new figure is created.
layout: layout algorithm to position nodes (e.g. :func:`networkx.drawing.layout.spring_layout`,
:func:`networkx.drawing.layout.kamada_kawai_layout`,
:func:`networkx.drawing.layout.circular_layout`,
:func:`networkx.drawing.layout.random_layout`,
:func:`networkx.drawing.layout.shell_layout`).
draw_kwargs: forwarded to :func:`networkx.drawing.nx_pylab.draw_networkx`.

Returns:
The matplotlib :class:`matplotlib.axes.Axes` containing the plot.

Raises:
RuntimeError: if matplotlib is not available.
ValueError: if an unsupported layout is requested.

"""
try:
import matplotlib.pyplot as plt # noqa: PLC0415
except Exception as exc: # pragma: no cover - runtime dependency guard
raise RuntimeError("matplotlib is required for plotting the network") from exc

layout_func = _LAYOUT_FUNCS.get(layout)
if layout_func is None:
supported = ", ".join(sorted(_LAYOUT_FUNCS))
raise ValueError(f"Unsupported layout '{layout}'. Supported layouts: {supported}")

pos = layout_func(graph)
if ax is None:
_, ax = plt.subplots()

draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs)
nx.drawing.nx_pylab.draw_networkx(
graph,
pos=pos,
ax=ax,
**draw_kwargs_dict,
)
return ax
1 change: 1 addition & 0 deletions docs/source/api/decent_bench.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ decent\_bench
decent_bench.costs
decent_bench.datasets
decent_bench.distributed_algorithms
decent_bench.utils.network_utils
decent_bench.networks
decent_bench.schemes

Expand Down
7 changes: 7 additions & 0 deletions docs/source/api/decent_bench.utils.network_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
decent_bench.utils.network_utils
================================

.. automodule:: decent_bench.utils.network_utils
:members:
:show-inheritance:
:undoc-members:
1 change: 1 addition & 0 deletions docs/source/api/decent_bench.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ decent\_bench.utils
decent_bench.utils.array
decent_bench.utils.interoperability
decent_bench.utils.logger
decent_bench.utils.network_utils
decent_bench.utils.progress_bar
decent_bench.utils.types

Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"tensorflow": (
"https://www.tensorflow.org/api_docs/python",
"https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv",
Expand Down
24 changes: 24 additions & 0 deletions docs/source/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,30 @@ corresponding abstracts.
benchmark_problem=problem,
)


Network utilities
-----------------
Plot a network explicitly when you need it:

.. code-block:: python

import networkx as nx
from decent_bench import benchmark_problem
from decent_bench.utils import network_utils
from decent_bench.costs import LinearRegressionCost

problem = benchmark_problem.create_regression_problem(LinearRegressionCost, n_agents=25, n_neighbors_per_agent=3)

# Plot using decent-bench helper (wraps :func:`networkx.drawing.nx_pylab.draw_networkx`)
network_utils.plot_network(problem.network_structure, layout="circular", with_labels=True)

# Or call NetworkX directly on the graph
pos = nx.drawing.layout.spring_layout(problem.network_structure)
nx.drawing.nx_pylab.draw_networkx(problem.network_structure, pos=pos, with_labels=True)

For more options, see the `NetworkX drawing guide <https://networkx.org/documentation/stable/reference/drawing.html>`_.


Interoperability requirement
----------------------------
Decent-Bench is designed to interoperate with multiple array/tensor frameworks (NumPy, PyTorch, JAX, etc.). To keep
Expand Down