Skip to content

Conversation

@XinghaoWu
Copy link
Collaborator

  1. Add FedAMP algo.
  2. Add the FedCAC paper link in README.md.
  3. Remove the default momentum in default_config.py to fix the bug when setting the optimizer to Adam.

2. Add FedCAC paper link in README.md.
3. Remove the default momentum in default_config.py to fix the bug when set optmizer to adam.
@XinghaoWu XinghaoWu requested a review from kxzxvbk October 25, 2023 15:43
@kxzxvbk kxzxvbk changed the title Add FedAMP algo and fix bugs. Feature(wxh): Add FedAMP algo and fix bugs. Oct 26, 2023
@kxzxvbk kxzxvbk mentioned this pull request Oct 26, 2023
20 tasks
@kxzxvbk kxzxvbk added the algorithm add new algorithm label Oct 26, 2023
class FedAMPClient(BaseClient):
"""
Overview:
This class is the base implementation of client in 'Bold but Cautious: Unlocking the Potential of Personalized
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct this document, not FedCAC.


def __init__(self, args, client_id, train_dataset, test_dataset=None):
"""
Initializing train dataset, test dataset(for personalized settings).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct this document, the purpose is to get a copy of local model.

super(FedAMPClient, self).__init__(args, client_id, train_dataset, test_dataset)
self.client_u = copy.deepcopy(self.model)

def FedAMP_Loss_client(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the get_model_difference function (defined in fling/utils/torch_utils.py ) for simplification.

from fling.utils.utils import weight_flatten

@CLIENT_REGISTRY.register('fedamp_client')
class FedAMPClient(BaseClient):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether this client is identical to FedProxClient? What's the differences?

coef = torch.zeros(self.args.client.client_num)
for j, mw in enumerate(self.client_ws):
if i == j: continue
sub = weights[i] - weights[j]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite it using fling.utils.get_model_difference

)
return participated_clients

def weight_flatten(model) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose that this function can be removed.

@kxzxvbk
Copy link
Collaborator

kxzxvbk commented Oct 26, 2023

Reformat the code before final merge.

@kxzxvbk
Copy link
Collaborator

kxzxvbk commented Oct 26, 2023

Add example configs for cifar100, mnist and tiny-imagenet.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

algorithm add new algorithm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants