Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion plato/servers/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def _process_reports(self):
# Check if we should aggregate weights directly or use deltas
# Try strategy's aggregate_weights first, fall back to aggregate_deltas
strategy_weights = None
if hasattr(self.aggregation_strategy, "aggregate_weights"):
if self._should_prefer_weight_aggregation():
strategy_weights = await self.aggregation_strategy.aggregate_weights(
self.updates, baseline_weights, weights_received, self.context
)
Expand Down Expand Up @@ -285,6 +285,20 @@ async def _process_reports(self):
self.clients_processed()
self.callback_handler.call_event("on_clients_processed", self)

def _should_prefer_weight_aggregation(self) -> bool:
"""Return whether the strategy should use direct weight aggregation."""
strategy_cls = type(self.aggregation_strategy)
aggregate_weights_impl = getattr(strategy_cls, "aggregate_weights", None)
aggregate_deltas_impl = getattr(strategy_cls, "aggregate_deltas", None)

if aggregate_weights_impl is None:
return False

return not (
aggregate_weights_impl is FedAvgAggregationStrategy.aggregate_weights
and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas
)

def clients_processed(self) -> None:
"""Additional work to be performed after client reports have been processed."""

Expand Down
82 changes: 82 additions & 0 deletions tests/servers/test_fedavg_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,85 @@ def test_fedavg_aggregation_skips_feature_payloads(temp_config):
)

assert aggregated is None


class DummyAlgorithm:
    """In-memory algorithm stand-in for server aggregation dispatch tests.

    The weights live in ``self.current`` as a name -> value mapping. Values
    only need to support ``clone()``, ``+`` and ``-``.
    """

    def __init__(self, baseline):
        """Initialise the current weights from clones of ``baseline``."""
        initial = {}
        for key, value in baseline.items():
            initial[key] = value.clone()
        self.current = initial

    def extract_weights(self):
        """Return a new mapping holding clones of the current weights."""
        snapshot = {}
        for key, value in self.current.items():
            snapshot[key] = value.clone()
        return snapshot

    def compute_weight_deltas(self, baseline_weights, weights_list):
        """Return one ``weights - baseline`` mapping per entry of ``weights_list``.

        Each result is keyed by the names found in ``baseline_weights``.
        """
        all_deltas = []
        for weights in weights_list:
            delta = {}
            for key in baseline_weights:
                delta[key] = weights[key] - baseline_weights[key]
            all_deltas.append(delta)
        return all_deltas

    def update_weights(self, deltas):
        """Add ``deltas`` to the current weights and return a cloned copy."""
        # Build the new mapping first so ``self.current`` is replaced in one step.
        summed = {}
        for key in self.current:
            summed[key] = self.current[key] + deltas[key]
        self.current = summed
        return self.extract_weights()

    def load_weights(self, weights):
        """Replace the current weights with clones of ``weights``."""
        replacement = {}
        for key, value in weights.items():
            replacement[key] = value.clone()
        self.current = replacement


class DeltaOnlyStrategy(FedAvgAggregationStrategy):
    """FedAvg strategy that customises only ``aggregate_deltas``.

    It counts how often the delta hook runs so a test can check dispatch.
    """

    def __init__(self):
        """Set up the parent strategy and zero the call counter."""
        super().__init__()
        # Number of times ``aggregate_deltas`` has been invoked on this instance.
        self.delta_calls = 0

    async def aggregate_deltas(self, updates, deltas_received, context):
        """Count the call, then return the parent's aggregation result."""
        self.delta_calls = self.delta_calls + 1
        aggregated = await super().aggregate_deltas(
            updates, deltas_received, context
        )
        return aggregated


def test_fedavg_server_prefers_custom_delta_strategy_over_inherited_weights(
    temp_config,
):
    """A strategy overriding only ``aggregate_deltas`` must have it invoked.

    Runs ``_process_reports`` with a single client update on an all-zero
    baseline and checks that the override ran once and that the algorithm's
    weights match the client's all-ones payload.
    """
    from plato.config import Config
    from plato.servers import fedavg

    Config().server.do_test = False

    delta_strategy = DeltaOnlyStrategy()
    server = fedavg.Server(aggregation_strategy=delta_strategy)

    # Wire a stub algorithm holding an all-zero baseline into the server.
    zero_baseline = {"weight": torch.zeros((1, 2)), "bias": torch.zeros(1)}
    server.algorithm = DummyAlgorithm(zero_baseline)
    server.context.algorithm = server.algorithm
    server.context.server = server
    server.context.state["prng_state"] = None

    # One client update whose payload is all ones.
    client_report = SimpleNamespace(
        num_samples=1,
        accuracy=0.5,
        processing_time=0.1,
        comm_time=0.1,
        training_time=0.1,
    )
    client_payload = {
        "weight": torch.ones((1, 2)),
        "bias": torch.ones(1),
    }
    client_update = SimpleNamespace(
        client_id=1,
        report=client_report,
        payload=client_payload,
    )
    server.updates = [client_update]

    asyncio.run(server._process_reports())

    assert delta_strategy.delta_calls == 1

    expected_weights = {
        "weight": torch.ones((1, 2)),
        "bias": torch.ones(1),
    }
    for name, expected in expected_weights.items():
        assert torch.allclose(server.algorithm.current[name], expected)