2323
2424from fedgraph .utils_nc import label_dirichlet_partition
2525
26- # DATASETS = ['cora', 'citeseer', 'pubmed']
27- DATASETS = ["pubmed" ]
28-
29- IID_BETAS = [10000.0 , 100.0 , 10.0 ]
30- CLIENT_NUM = 10
31- TOTAL_ROUNDS = 200
32- LOCAL_STEPS = 1
33- LEARNING_RATE = 0.1
34- HIDDEN_DIM = 64
26+ DATASETS = ["cora" ]
27+
28+ IID_BETAS = [100.0 ]
29+ CLIENT_NUM = 5
30+ TOTAL_ROUNDS = 100
31+ LOCAL_STEPS = 3
32+ LEARNING_RATE = 0.5
33+ HIDDEN_DIM = 16
3534DROPOUT_RATE = 0.5
36- CPUS_PER_TRAINER = 0.6
35+ CPUS_PER_TRAINER = 1
3736STANDALONE_PROCESSES = 1
3837
3938PLANETOID_NAMES = {"cora" : "Cora" , "citeseer" : "CiteSeer" , "pubmed" : "PubMed" }
@@ -68,32 +67,48 @@ def load_data(config, client_cfgs=None):
6867 ds = Planetoid (root = "data/" , name = PLANETOID_NAMES [name ])
6968 full = ds [0 ]
7069 num_classes = int (full .y .max ().item ()) + 1
71- # Dirichlet partition across all nodes
70+
71+ # Identical to data_process.py: run the Dirichlet partition over all nodes
7272 split_idxs = label_dirichlet_partition (
73- full .y ,
74- full .num_nodes ,
73+ full .y , # use all nodes, not only the training nodes
74+ len ( full .y ), # total number of nodes
7575 num_classes ,
7676 config .federate .client_num ,
7777 config .iid_beta ,
7878 config .distribution_type ,
7979 )
80+
8081 parts = []
8182 for idxs in split_idxs :
82- mask = torch .zeros (full .num_nodes , dtype = torch .bool )
83- mask [idxs ] = True
83+ client_nodes = torch .tensor (idxs )
84+
85+ # Create masks for each client while keeping the original train/val/test split logic
86+ train_mask = torch .zeros (full .num_nodes , dtype = torch .bool )
87+ val_mask = torch .zeros (full .num_nodes , dtype = torch .bool )
88+ test_mask = torch .zeros (full .num_nodes , dtype = torch .bool )
89+
90+ # Among the client's nodes, keep the original dataset's train/val/test split
91+ for node in client_nodes :
92+ if full .train_mask [node ]:
93+ train_mask [node ] = True
94+ elif full .val_mask [node ]:
95+ val_mask [node ] = True
96+ elif full .test_mask [node ]:
97+ test_mask [node ] = True
98+
8499 parts .append (
85100 Data (
86101 x = full .x ,
87102 edge_index = full .edge_index ,
88103 y = full .y ,
89- train_mask = mask ,
90- val_mask = mask ,
91- test_mask = mask ,
104+ train_mask = train_mask , # keep the original train split
105+ val_mask = val_mask , # keep the original val split
106+ test_mask = test_mask , # keep the original test split
92107 )
93108 )
109+
94110 data_dict = {
95- i
96- + 1 : {
111+ i + 1 : {
97112 "data" : parts [i ],
98113 "train" : [parts [i ]],
99114 "val" : [parts [i ]],
@@ -124,66 +139,6 @@ def build(cfg_model, input_shape):
124139register_model (mkey , builder )
125140
126141
127- def run_fedavg_manual (ds , beta , rounds , clients ):
128- device = torch .device ("cpu" )
129- ds_obj = Planetoid (root = "data/" , name = PLANETOID_NAMES [ds ])
130- data = ds_obj [0 ].to (device )
131- in_channels = data .x .size (1 )
132- num_classes = int (data .y .max ().item ()) + 1
133- train_idx = data .train_mask .nonzero (as_tuple = False ).view (- 1 )
134- # Dirichlet partition over all nodes
135- split_idxs = label_dirichlet_partition (
136- data .y , data .num_nodes , num_classes , clients , beta , "average"
137- )
138- client_idxs = []
139- train_set = set (train_idx .tolist ())
140- for idxs in split_idxs :
141- ti = [i for i in idxs if i in train_set ]
142- client_idxs .append (torch .tensor (ti , dtype = torch .long ))
143- global_model = TwoLayerGCN (in_channels , num_classes ).to (device )
144- global_params = [p .data .clone () for p in global_model .parameters ()]
145- t0 = time .time ()
146- for _ in range (rounds ):
147- local_params = []
148- for cid in range (clients ):
149- m = TwoLayerGCN (in_channels , num_classes ).to (device )
150- for p , gp in zip (m .parameters (), global_params ):
151- p .data .copy_ (gp )
152- opt = torch .optim .SGD (m .parameters (), lr = LEARNING_RATE )
153- m .train ()
154- opt .zero_grad ()
155- out = m (data )
156- loss = F .cross_entropy (out [client_idxs [cid ]], data .y [client_idxs [cid ]])
157- loss .backward ()
158- opt .step ()
159- local_params .append ([p .data .clone () for p in m .parameters ()])
160- with torch .no_grad ():
161- for gp in global_params :
162- gp .zero_ ()
163- for lp in local_params :
164- for gp , p in zip (global_params , lp ):
165- gp .add_ (p )
166- for gp in global_params :
167- gp .div_ (clients )
168- dur = time .time () - t0
169- for p , gp in zip (global_model .parameters (), global_params ):
170- p .data .copy_ (gp )
171- global_model .eval ()
172- with torch .no_grad ():
173- preds = global_model (data ).argmax (dim = 1 )
174- correct = (
175- (
176- preds [data .test_mask .nonzero (as_tuple = False ).view (- 1 )]
177- == data .y [data .test_mask .nonzero (as_tuple = False ).view (- 1 )]
178- )
179- .sum ()
180- .item ()
181- )
182- acc = correct / data .test_mask .sum ().item ()
183- total_params = sum (p .numel () for p in global_model .parameters ())
184- model_size_mb = total_params * 4 / 1024 ** 2
185- return acc , model_size_mb , total_params , dur
186-
187142
188143def run_fedscope_experiment (ds , beta ):
189144 cfg = global_cfg .clone ()
@@ -194,8 +149,8 @@ def run_fedscope_experiment(ds, beta):
194149 cfg .federate .mode = "standalone"
195150 cfg .federate .client_num = CLIENT_NUM
196151 cfg .federate .total_round_num = TOTAL_ROUNDS
197- cfg .federate .make_global_eval = True
198- cfg .federate .process_num = STANDALONE_PROCESSES
152+ cfg .federate .make_global_eval = False
153+ cfg .federate .process_num = CLIENT_NUM
199154 cfg .federate .num_cpus_per_trainer = CPUS_PER_TRAINER
200155 cfg .data .root = "data/"
201156 cfg .data .type = ds
@@ -214,6 +169,7 @@ def run_fedscope_experiment(ds, beta):
214169 cfg .train .local_update_steps = LOCAL_STEPS
215170 cfg .train .optimizer .lr = LEARNING_RATE
216171 cfg .train .optimizer .weight_decay = 0.0
172+ cfg .train .optimizer .type = "SGD"
217173 cfg .eval .freq = 1
218174 cfg .eval .metrics = ["acc" ]
219175 cfg .freeze ()
@@ -224,11 +180,21 @@ def run_fedscope_experiment(ds, beta):
224180 res = runner .run ()
225181 dur = time .time () - t0
226182 mem = peak_memory_mb ()
227- acc = res .get ("server_global_eval" , res ).get ("test_acc" , res .get ("acc" , 0.0 ))
183+
184+ # Fetch the FederatedScope results
185+
186+ # Read the accuracy from the FederatedScope results
187+ # Use the weighted average to stay consistent with FedGraph
188+ acc = res .get ("client_summarized_weighted_avg" , {}).get ("test_acc" , 0.0 ) if res else 0.0
189+
228190 acc_pct = acc * 100 if acc <= 1.0 else acc
229- model = runner .server .model
230- tot_params = sum (p .numel () for p in model .parameters ())
231- msz = tot_params * 4 / 1024 ** 2
191+ model = runner .server .model if runner .server else None
192+ if model is not None :
193+ tot_params = sum (p .numel () for p in model .parameters ())
194+ msz = tot_params * 4 / 1024 ** 2
195+ else :
196+ tot_params = 0
197+ msz = 0.0
232198 comm = calculate_communication_cost (msz , TOTAL_ROUNDS , CLIENT_NUM )
233199 return {
234200 "accuracy" : acc_pct ,
@@ -244,26 +210,6 @@ def run_fedscope_experiment(ds, beta):
244210 }
245211
246212
247- def run_manual_experiment (ds , beta ):
248- if ds == "citeseer" :
249- nodes , edges = 3327 , 9104
250- else :
251- nodes , edges = 19717 , 88648
252- acc , msz , tp , dur = run_fedavg_manual (ds , beta , TOTAL_ROUNDS , CLIENT_NUM )
253- mem = peak_memory_mb ()
254- comm = calculate_communication_cost (msz , TOTAL_ROUNDS , CLIENT_NUM )
255- return {
256- "accuracy" : acc * 100 ,
257- "total_time" : dur ,
258- "computation_time" : dur ,
259- "communication_cost_mb" : comm ,
260- "peak_memory_mb" : mem ,
261- "avg_time_per_round" : dur / TOTAL_ROUNDS ,
262- "model_size_mb" : msz ,
263- "total_params" : tp ,
264- "nodes" : nodes ,
265- "edges" : edges ,
266- }
267213
268214
269215def main ():
@@ -280,8 +226,6 @@ def main():
280226 print (f"Running { ds } with β={ beta } " )
281227 if ds == "cora" :
282228 metrics = run_fedscope_experiment (ds , beta )
283- else :
284- metrics = run_manual_experiment (ds , beta )
285229 print (
286230 f"Dataset: { metrics ['nodes' ]:,} nodes, { metrics ['edges' ]:,} edges"
287231 )
0 commit comments