From 9385b5f67add502a9dcdbaf35fd61382dcf88f30 Mon Sep 17 00:00:00 2001
From: Baochun Li
Date: Tue, 10 Mar 2026 15:04:35 -0400
Subject: [PATCH] Fix FedAvg strategy dispatch for delta-only overrides.

The composable FedAvg server always preferred aggregate_weights() when
the aggregation strategy exposed that method. This broke strategies that
inherit from FedAvgAggregationStrategy but override only
aggregate_deltas(), because the inherited FedAvg weight aggregation path
ran first and bypassed the custom delta logic entirely.

Add a dispatch guard in plato/servers/fedavg.py so the server does not
take the inherited FedAvg aggregate_weights() fast path when the strategy
class has customized aggregate_deltas() instead. This restores the
intended execution path for Polaris and other similar strategies using
the strategy-based server design.

Add a regression test covering Server._process_reports() with a
delta-only FedAvg strategy subclass to ensure custom aggregate_deltas()
hooks are not shadowed by inherited weight aggregation.
--- plato/servers/fedavg.py | 16 +++++- tests/servers/test_fedavg_strategy.py | 82 +++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 36fee944b..e3418eba6 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -191,7 +191,7 @@ async def _process_reports(self): # Check if we should aggregate weights directly or use deltas # Try strategy's aggregate_weights first, fall back to aggregate_deltas strategy_weights = None - if hasattr(self.aggregation_strategy, "aggregate_weights"): + if self._should_prefer_weight_aggregation(): strategy_weights = await self.aggregation_strategy.aggregate_weights( self.updates, baseline_weights, weights_received, self.context ) @@ -285,6 +285,20 @@ async def _process_reports(self): self.clients_processed() self.callback_handler.call_event("on_clients_processed", self) + def _should_prefer_weight_aggregation(self) -> bool: + """Return whether the strategy should use direct weight aggregation.""" + strategy_cls = type(self.aggregation_strategy) + aggregate_weights_impl = getattr(strategy_cls, "aggregate_weights", None) + aggregate_deltas_impl = getattr(strategy_cls, "aggregate_deltas", None) + + if aggregate_weights_impl is None: + return False + + return not ( + aggregate_weights_impl is FedAvgAggregationStrategy.aggregate_weights + and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas + ) + def clients_processed(self) -> None: """Additional work to be performed after client reports have been processed.""" diff --git a/tests/servers/test_fedavg_strategy.py b/tests/servers/test_fedavg_strategy.py index f647ff7d9..2f5999964 100644 --- a/tests/servers/test_fedavg_strategy.py +++ b/tests/servers/test_fedavg_strategy.py @@ -59,3 +59,85 @@ def test_fedavg_aggregation_skips_feature_payloads(temp_config): ) assert aggregated is None + + +class DummyAlgorithm: + """Minimal algorithm stub for server aggregation dispatch 
tests.""" + + def __init__(self, baseline): + self.current = {name: tensor.clone() for name, tensor in baseline.items()} + + def extract_weights(self): + return {name: tensor.clone() for name, tensor in self.current.items()} + + def compute_weight_deltas(self, baseline_weights, weights_list): + return [ + { + name: weights[name] - baseline_weights[name] + for name in baseline_weights.keys() + } + for weights in weights_list + ] + + def update_weights(self, deltas): + self.current = { + name: self.current[name] + deltas[name] for name in self.current.keys() + } + return self.extract_weights() + + def load_weights(self, weights): + self.current = {name: tensor.clone() for name, tensor in weights.items()} + + +class DeltaOnlyStrategy(FedAvgAggregationStrategy): + """Strategy overriding only delta aggregation to exercise dispatch.""" + + def __init__(self): + super().__init__() + self.delta_calls = 0 + + async def aggregate_deltas(self, updates, deltas_received, context): + self.delta_calls += 1 + return await super().aggregate_deltas(updates, deltas_received, context) + + +def test_fedavg_server_prefers_custom_delta_strategy_over_inherited_weights( + temp_config, +): + """Custom delta strategies should not be bypassed by inherited weight hooks.""" + from plato.config import Config + from plato.servers import fedavg + + Config().server.do_test = False + + strategy = DeltaOnlyStrategy() + server = fedavg.Server(aggregation_strategy=strategy) + + baseline = {"weight": torch.zeros((1, 2)), "bias": torch.zeros(1)} + server.algorithm = DummyAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + + server.updates = [ + SimpleNamespace( + client_id=1, + report=SimpleNamespace( + num_samples=1, + accuracy=0.5, + processing_time=0.1, + comm_time=0.1, + training_time=0.1, + ), + payload={ + "weight": torch.ones((1, 2)), + "bias": torch.ones(1), + }, + ) + ] + + 
asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert torch.allclose(server.algorithm.current["weight"], torch.ones((1, 2))) + assert torch.allclose(server.algorithm.current["bias"], torch.ones(1))