3 changes: 3 additions & 0 deletions .gitignore
@@ -93,6 +93,9 @@ target/
profile_default/
ipython_config.py

# Node
/src/node_modules/

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
2 changes: 1 addition & 1 deletion src/algos/MetaL2C.py
@@ -122,7 +122,7 @@ def __init__(
) -> None:
super().__init__(config, comm_utils)

self.encoder = ModelEncoder(self.get_model_weights())
self.encoder = ModelEncoder(self.get_model_weights(get_external_repr=False))
self.encoder_optim = optim.SGD(
self.encoder.parameters(), lr=self.config["alpha_lr"]
)
31 changes: 28 additions & 3 deletions src/algos/base_class.py
@@ -29,6 +29,12 @@
TransformDataset,
CorruptDataset,
)

# import the possible attacks
from algos.attack_add_noise import AddNoiseAttack
from algos.attack_bad_weights import BadWeightsAttack
from algos.attack_sign_flip import SignFlipAttack

from utils.log_utils import LogUtils
from utils.model_utils import ModelUtils
from utils.community_utils import (
@@ -94,6 +100,7 @@ class BaseNode(ABC):
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
self.config = config
self.set_constants()
self.config = config
self.comm_utils = comm_utils
@@ -123,6 +130,7 @@ def __init__(

if "gia" in config and self.node_id in config["gia_attackers"]:
self.gia_attacker = True
self.malicious_type = config.get("malicious_type", "normal")

self.log_memory = config.get("log_memory", False)

@@ -170,7 +178,7 @@ def setup_logging(self, config: ConfigType) -> None:
def setup_cuda(self, config: ConfigType) -> None:
"""add docstring here"""
# Need a mapping from rank to device id
if (config.get("assign_based_on_host", False)):
if (config.get("assign_based_on_host", False)) == False:
device_ids_map = config["device_ids"]
node_name = f"node_{self.node_id}"
self.device_ids = device_ids_map[node_name]
@@ -266,7 +274,7 @@ def set_shared_exp_parameters(self, config: ConfigType) -> None:
def local_round_done(self) -> None:
self.round += 1

def get_model_weights(self, chop_model:bool=False) -> Dict[str, int|Dict[str, Any]]:
def get_model_weights(self, chop_model:bool=False, get_external_repr:bool=True) -> Dict[str, int|Dict[str, Any]]:
"""
Share the model weights
params:
@@ -275,6 +283,9 @@ def get_model_weights(self, chop_model:bool=False) -> Dict[str, int|Dict[str, An
if chop_model:
model, _ = self.model_utils.get_split_model(self.model, self.config["split_layer"])
model = model.state_dict()
elif get_external_repr and self.malicious_type != "normal":
# Get the external representation of the malicious model
model = self.get_malicious_model_weights()
else:
model = self.model.state_dict()
message: Dict[str, int|Dict[str, Any]] = {"sender": self.node_id, "round": self.round, "model": model}
@@ -290,6 +301,20 @@ def get_model_weights(self, chop_model:bool=False) -> Dict[str, int|Dict[str, An
message["model"][key] = message["model"][key].to("cpu")

return message

def get_malicious_model_weights(self) -> Dict[str, Tensor]:
"""
Get the external representation of the model based on the malicious type.
"""
if self.malicious_type == "sign_flip":
return SignFlipAttack(self.config, self.model.state_dict()).get_representation()
elif self.malicious_type == "bad_weights":
# print("bad weights attack")
return BadWeightsAttack(self.config, self.model.state_dict()).get_representation()
elif self.malicious_type == "add_noise":
return AddNoiseAttack(self.config, self.model.state_dict()).get_representation()
else:
return self.model.state_dict()

def get_local_rounds(self) -> int:
return self.round
@@ -1104,7 +1129,7 @@ def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None:
total_weight = 0.0 # To re-normalize weights after handling dropouts

# Include the current node's model in the aggregation
current_model_wts = self.get_model_weights()
current_model_wts = self.get_model_weights(get_external_repr=False) # internal model representation
assert "model" in current_model_wts, "Model not found in the current model."
current_model_wts = current_model_wts["model"]
current_weight = 1.0 / (len(neighbors) + 1) # Weight for the current node
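For context, the new get_malicious_model_weights helper delegates to the attack classes imported at the top of base_class.py, each constructed with (config, state_dict) and exposing get_representation(). The attack implementations themselves are not part of this diff; below is a minimal sketch of what a compatible SignFlipAttack could look like — only the constructor/get_representation interface is taken from the calls above, the body is an assumption.

# Sketch only (not from this PR): a SignFlipAttack matching the call sites above.
from typing import Any, Dict

import torch


class SignFlipAttack:
    """Minimal illustrative attack compatible with BaseNode.get_malicious_model_weights."""

    def __init__(self, config: Dict[str, Any], state_dict: Dict[str, torch.Tensor]) -> None:
        self.config = config
        self.state_dict = state_dict

    def get_representation(self) -> Dict[str, torch.Tensor]:
        # Flip the sign of every floating-point parameter before sharing;
        # integer buffers (e.g. BatchNorm's num_batches_tracked) are left untouched.
        return {
            key: -val if torch.is_floating_point(val) else val
            for key, val in self.state_dict.items()
        }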
90 changes: 45 additions & 45 deletions src/algos/fl.py
@@ -37,52 +37,52 @@ def local_test(self, **kwargs: Any) -> Tuple[float, float, float]:
return test_loss, test_acc, time_taken


def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
"""
Overwrite the get_model_weights method of the BaseNode
to add malicious attacks
TODO: this should be moved to BaseClient
"""

message = {"sender": self.node_id, "round": self.round}

malicious_type = self.config.get("malicious_type", "normal")

if malicious_type == "normal":
message["model"] = self.model.state_dict() # type: ignore
elif malicious_type == "bad_weights":
# Corrupt the weights
message["model"] = BadWeightsAttack(
self.config, self.model.state_dict()
).get_representation()
elif malicious_type == "sign_flip":
# Flip the sign of the weights, also TODO: consider label flipping
message["model"] = SignFlipAttack(
self.config, self.model.state_dict()
).get_representation()
elif malicious_type == "add_noise":
# Add noise to the weights
message["model"] = AddNoiseAttack(
self.config, self.model.state_dict()
).get_representation()
else:
message["model"] = self.model.state_dict() # type: ignore

# move the model to cpu before sending
for key in message["model"].keys():
message["model"][key] = message["model"][key].to("cpu")

# assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
# also stream image and labels
message["images"] = self.images.to("cpu")
message["labels"] = self.labels.to("cpu")

message["random_params"] = self.random_params
for key in message["random_params"].keys():
message["random_params"][key] = message["random_params"][key].to("cpu")
# def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
# """
# Overwrite the get_model_weights method of the BaseNode
# to add malicious attacks
# TODO: this should be moved to BaseClient
# """

# message = {"sender": self.node_id, "round": self.round}

# malicious_type = self.config.get("malicious_type", "normal")

# if malicious_type == "normal":
# message["model"] = self.model.state_dict() # type: ignore
# elif malicious_type == "bad_weights":
# # Corrupt the weights
# message["model"] = BadWeightsAttack(
# self.config, self.model.state_dict()
# ).get_representation()
# elif malicious_type == "sign_flip":
# # Flip the sign of the weights, also TODO: consider label flipping
# message["model"] = SignFlipAttack(
# self.config, self.model.state_dict()
# ).get_representation()
# elif malicious_type == "add_noise":
# # Add noise to the weights
# message["model"] = AddNoiseAttack(
# self.config, self.model.state_dict()
# ).get_representation()
# else:
# message["model"] = self.model.state_dict() # type: ignore

# # move the model to cpu before sending
# for key in message["model"].keys():
# message["model"][key] = message["model"][key].to("cpu")

# # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
# if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
# # also stream image and labels
# message["images"] = self.images.to("cpu")
# message["labels"] = self.labels.to("cpu")

# message["random_params"] = self.random_params
# for key in message["random_params"].keys():
# message["random_params"][key] = message["random_params"][key].to("cpu")

return message # type: ignore
# return message # type: ignore

def run_protocol(self):
print(f"Client {self.node_id} ready to start training")
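With the malicious branching now handled by BaseNode.get_model_weights (see base_class.py above), the fl.py override is retired in favor of the shared implementation. A hedged usage sketch follows; the client class name and any config keys beyond malicious_type are placeholders, not taken from this diff.

# Illustrative only; "FedAvgClient" is a stand-in for whatever client class fl.py defines.
config = {
    "malicious_type": "sign_flip",   # or "normal", "bad_weights", "add_noise"
    # ... remaining experiment keys ...
}
# client = FedAvgClient(config, comm_utils)
# message = client.get_model_weights()                         # external (possibly attacked) weights
# honest = client.get_model_weights(get_external_repr=False)   # internal weights, e.g. for local aggregation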
12 changes: 6 additions & 6 deletions src/configs/algo_config.py
@@ -204,13 +204,12 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
# Collaboration setup
"algo": "fedstatic",
"topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
# "topology": {"name": "base_graph", "max_degree": 2}, # type: ignore
"rounds": 3,
"rounds": 200,
# Model parameters
"optimizer": "sgd", # TODO comment out for real training
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
"model_lr": 0.1, # 3e-4,
"batch_size": 64,
}

swift: ConfigType = {
@@ -231,11 +230,12 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
# comparison describes the metric or algorithm used to compare the weights of the models
# sampling describes the method used to sample the neighbors after the comparison
"topology": {"comparison": "weights_l2", "sampling": "closest"}, # type: ignore
"rounds": 20,
"rounds": 200,

# Model parameters
"optimizer": "sgd",
"model": "resnet10",
"model_lr": 3e-4,
"model_lr": 0.1,
"batch_size": 256,
}

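For orientation, the swift topology keys describe dynamic neighbor selection: models are compared by a metric ("weights_l2") and neighbors are then sampled ("closest"). A sketch of one way this could be realized; the repository's actual comparison and sampling code is not part of this diff.

# Sketch (assumption): L2 comparison over state_dicts and closest-neighbor sampling.
import torch

def l2_distance(sd_a, sd_b) -> float:
    # L2 distance between two state_dicts, summed over matching keys.
    return sum(torch.norm(sd_a[k].float() - sd_b[k].float()).item() for k in sd_a)

def closest_neighbors(own_sd, peer_sds, k):
    # peer_sds: {node_id: state_dict}; return ids of the k most similar peers.
    return sorted(peer_sds, key=lambda nid: l2_distance(own_sd, peer_sds[nid]))[:k]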
8 changes: 5 additions & 3 deletions src/configs/algo_config_test.py
@@ -3,12 +3,14 @@
fedstatic: ConfigType = {
# Collaboration setup
"algo": "fedstatic",
"topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
"rounds": 1,
# "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
"topology": {"name": "ring"},
"rounds": 200,

# Model parameters
"optimizer": "sgd",
"model": "resnet10",
"model_lr": 3e-4,
"model_lr": 0.1,
"batch_size": 256,
}

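The test config swaps the Watts–Strogatz topology for a ring and runs 200 rounds. As an illustration of how such a topology entry might be expanded into a neighbor graph — a sketch using networkx; the repository's own topology utilities may be implemented differently:

# Sketch (assumption): mapping a topology dict to a neighbor graph.
import networkx as nx

def build_topology(topology: dict, num_users: int) -> nx.Graph:
    if topology["name"] == "watts_strogatz":
        return nx.watts_strogatz_graph(num_users, topology["k"], topology["p"])
    if topology["name"] == "ring":
        return nx.cycle_graph(num_users)
    raise ValueError(f"unknown topology: {topology['name']}")

# e.g. neighbors of node 0 in a 36-user ring: [1, 35]
# sorted(build_topology({"name": "ring"}, 36).neighbors(0))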
4 changes: 3 additions & 1 deletion src/configs/malicious_config.py
@@ -1,6 +1,7 @@
# Malicious Configuration
from utils.types import ConfigType
from typing import Dict
import random

# Weight Update Attacks
sign_flip: ConfigType = {
@@ -49,7 +50,7 @@
label_flip: ConfigType = {
"malicious_type": "label_flip",
"permute_labels": 10,
# "permutation": random.shuffle([i for i in range(10)]),
"permutation": random.sample(range(10), 10) # Generates a random permutation of labels 0-9
}

# List of Malicious node configurations
@@ -60,4 +61,5 @@
"gradient_attack": gradient_attack,
"backdoor_attack": backdoor_attack,
"data_poisoning": data_poisoning,
"label_flip": label_flip,
}
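
One note on the change above: the previously commented-out random.shuffle call shuffles in place and returns None, so random.sample(range(10), 10) is the working way to draw a permutation. A sketch of how such a permutation might be applied to a batch of labels — illustrative only, the actual label-flipping client code is not in this diff:

# Sketch (assumption): applying the generated permutation to labels.
import random
import torch

permutation = random.sample(range(10), 10)   # e.g. [3, 7, 0, 9, 1, 5, 2, 8, 6, 4]
lookup = torch.tensor(permutation)

labels = torch.tensor([0, 1, 2, 9])          # original CIFAR-10 labels
flipped = lookup[labels]                     # labels remapped through the permutation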
20 changes: 15 additions & 5 deletions src/configs/sys_config.py
@@ -159,8 +159,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
CIFAR10_DSET = "cifar10"
CIAR10_DPATH = "./datasets/imgs/cifar10/"

NUM_COLLABORATORS = 3
DUMP_DIR = "/tmp/new_sonar/"
NUM_COLLABORATORS = 1
DUMP_DIR = "/mas/camera/Experiments/SONAR/jyuan/_tmp/"

num_users = 9
mpi_system_config: ConfigType = {
@@ -327,6 +327,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"dropout_correlation": 0.0, # correlation between dropouts of successive rounds: [0,1]
}

dropout_dict = {} #empty dict to disable dropout
dropout_dicts: Any = {"node_0": {}}
for i in range(1, num_users + 1):
dropout_dicts[f"node_{i}"] = dropout_dict
@@ -346,14 +347,23 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"device_ids": get_device_ids(num_users, gpu_ids),
"assign_based_on_host": True,
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_weights]), # type: ignore
"samples_per_user": 500, # distributed equally
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
"samples_per_user": 50000 // num_users, # distributed equally
"train_label_distribution": "non_iid",
"alpha_data": 0.1,
"test_label_distribution": "iid",
"exp_keys": [],
"dropout_dicts": dropout_dicts,
"log_memory": False,
"test_samples_per_user": 200,
"log_memory": True,
"streaming_aggregation": True, # Make it true for fedstatic
# "assign_based_on_host": True,
# "hostname_to_device_ids": {
# "matlaber1": [2, 3, 4, 5, 6, 7],
# "matlaber12": [0, 1, 2, 3],
# "matlaber3": [0, 1, 2, 3],
# "matlaber4": [0, 2, 3, 4, 5, 6, 7],
# }
}

grpc_system_config_gia: ConfigType = {
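Putting the configuration changes above together (a worked sketch using only values visible in this file): samples_per_user is now an even split of the CIFAR-10 training set, and wiring an empty dict into dropout_dicts disables dropout for every node.

num_users = 9                           # value set earlier in sys_config.py
samples_per_user = 50000 // num_users   # 5555 samples per user

dropout_dict = {}                       # empty dict -> dropout disabled
dropout_dicts = {"node_0": {}}
for i in range(1, num_users + 1):
    dropout_dicts[f"node_{i}"] = dropout_dict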
26 changes: 14 additions & 12 deletions src/configs/sys_config_test.py
@@ -81,10 +81,10 @@ def get_algo_configs(
CIFAR10_DSET = "cifar10"
CIAR10_DPATH = "./datasets/imgs/cifar10/"

DUMP_DIR = "/tmp/"
DUMP_DIR = "/mas/camera/Experiments/SONAR/jyuan/_tmp/"

NUM_COLLABORATORS = 1
num_users = 4
NUM_COLLABORATORS = 36
num_users = 36

dropout_dict = {
"distribution_dict": { # leave dict empty to disable dropout
@@ -101,34 +101,36 @@

gpu_ids = [2, 3, 5, 6]

topo = "torus"
algo_name = "no_malicious"
num_collaborators = NUM_COLLABORATORS

grpc_system_config: ConfigType = {
"exp_id": "static",
"exp_id": f"topo_{topo}x{algo_name}_{0}_malicious_{num_collaborators}_colab_3_4",
"num_users": num_users,
"num_collaborators": NUM_COLLABORATORS,
"comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]}, # The super-node
"comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["matlaber1.media.mit.edu:1112"]}, # The super-node
"dset": CIFAR10_DSET,
"dump_dir": DUMP_DIR,
"dpath": CIAR10_DPATH,
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
# "samples_per_user": 50000 // num_users, # distributed equally
"samples_per_user": 100,
"samples_per_user": 50000 // num_users, # distributed equally
"train_label_distribution": "non_iid",
"test_label_distribution": "iid",
"alpha_data": 1.0,
"exp_keys": [],
"dropout_dicts": dropout_dicts,
"test_samples_per_user": 200,
"log_memory": True,
# "streaming_aggregation": True, # Make it true for fedstatic
"streaming_aggregation": True, # Make it true for fedstatic
"assign_based_on_host": True,
"hostname_to_device_ids": {
"matlaber1": [2, 3, 4, 5, 6, 7],
"matlaber12": [0, 1, 2, 3],
"matlaber3": [0, 1, 2, 3],
"matlaber4": [0, 2, 3, 4, 5, 6, 7],
"matlaber1": [2, 3, 4, 6, 7],
"matlaber5": [1, 2],
"matlaber12": [2, 3],
}
}
current_config = grpc_system_config
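For context on assign_based_on_host and hostname_to_device_ids, here is a sketch of how a per-node CUDA device might be resolved from the host's GPU list; the repository's get_device_ids/setup_cuda implementations may differ in detail.

# Sketch (assumption): resolving a node's device from the hostname map above.
import socket

hostname_to_device_ids = {
    "matlaber1": [2, 3, 4, 6, 7],
    "matlaber5": [1, 2],
    "matlaber12": [2, 3],
}

def pick_device(node_id: int) -> str:
    host = socket.gethostname().split(".")[0]      # e.g. "matlaber1"
    gpus = hostname_to_device_ids[host]
    return f"cuda:{gpus[node_id % len(gpus)]}"     # round-robin over this host's GPUs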
4 changes: 2 additions & 2 deletions src/main.py
@@ -12,8 +12,8 @@
logging.basicConfig(level=logging.DEBUG) # Enable detailed logging

# Default config file paths
B_DEFAULT: str = "./configs/algo_config.py"
S_DEFAULT: str = "./configs/sys_config.py"
B_DEFAULT: str = "./configs/algo_config_test.py"
S_DEFAULT: str = "./configs/sys_config_test.py"

# Parse args
parser : argparse.ArgumentParser = argparse.ArgumentParser(description="Run collaborative learning experiments")