11#!/usr/bin/env python3
2- import warnings , logging
3- warnings .filterwarnings ('ignore' )
2+ import logging
3+ import warnings
4+
5+ warnings .filterwarnings ("ignore" )
46logging .disable (logging .CRITICAL )
57
6- import argparse , time , resource , torch , torch .nn .functional as F
7- from torch_geometric .nn import GCNConv
8- from torch_geometric .datasets import Planetoid
8+ import argparse
9+ import os
10+ import resource
11+ import time
12+
913import numpy as np
14+ import torch
15+ import torch .nn .functional as F
16+ from torch .distributed import destroy_process_group , init_process_group
17+ from torch .nn .parallel import DistributedDataParallel as DDP
18+ from torch_geometric .datasets import Planetoid
1019
1120# Distributed PyG imports
1221from torch_geometric .loader import NeighborLoader
13- from torch .distributed import init_process_group , destroy_process_group
14- from torch .nn .parallel import DistributedDataParallel as DDP
15- import os
22+ from torch_geometric .nn import GCNConv
1623
17- DATASETS = ['cora', 'citeseer', 'pubmed']
24+ DATASETS = ["cora", "citeseer", "pubmed"]
1825IID_BETAS = [10000.0 , 100.0 , 10.0 ]
1926CLIENT_NUM = 10
2027TOTAL_ROUNDS = 200
2330HIDDEN_DIM = 64
2431DROPOUT_RATE = 0.0
2532
26- PLANETOID_NAMES = {
27- 'cora' : 'Cora' ,
28- 'citeseer' : 'CiteSeer' ,
29- 'pubmed' : 'PubMed'
30- }
33+ PLANETOID_NAMES = {"cora" : "Cora" , "citeseer" : "CiteSeer" , "pubmed" : "PubMed" }
34+
3135
3236def peak_memory_mb ():
3337 usage = resource .getrusage (resource .RUSAGE_SELF ).ru_maxrss
3438 return (usage / 1024 ** 2 ) if usage > 1024 ** 2 else (usage / 1024 )
3539
40+
3641def calculate_communication_cost (model_size_mb , rounds , clients ):
3742 cost_per_round = model_size_mb * clients * 2
3843 return cost_per_round * rounds
3944
45+
4046def dirichlet_partition (labels , num_clients , alpha ):
4147 labels = labels .cpu ().numpy ()
4248 num_classes = labels .max () + 1
@@ -56,18 +62,21 @@ def dirichlet_partition(labels, num_clients, alpha):
5662
5763 return [torch .tensor (ci , dtype = torch .long ) for ci in client_idxs ]
5864
65+
5966class DistributedGCN (torch .nn .Module ):
60- def __init__ (self , in_channels , hidden_channels , out_channels , num_layers = 2 , dropout = 0.0 ):
67+ def __init__ (
68+ self , in_channels , hidden_channels , out_channels , num_layers = 2 , dropout = 0.0
69+ ):
6170 super ().__init__ ()
6271 self .num_layers = num_layers
6372 self .dropout = dropout
64-
73+
6574 self .convs = torch .nn .ModuleList ()
6675 self .convs .append (GCNConv (in_channels , hidden_channels ))
6776 for _ in range (num_layers - 2 ):
6877 self .convs .append (GCNConv (hidden_channels , hidden_channels ))
6978 self .convs .append (GCNConv (hidden_channels , out_channels ))
70-
79+
7180 def forward (self , x , edge_index ):
7281 for i , conv in enumerate (self .convs ):
7382 x = conv (x , edge_index )
@@ -76,199 +85,204 @@ def forward(self, x, edge_index):
7685 x = F .dropout (x , p = self .dropout , training = self .training )
7786 return x
7887
88+
7989def setup_distributed (rank , world_size ):
8090 """Initialize distributed training"""
81-     os.environ['MASTER_ADDR'] = 'localhost'
82-     os.environ['MASTER_PORT'] = '12355'
91+     os.environ["MASTER_ADDR"] = "localhost"
92+     os.environ["MASTER_PORT"] = "12355"
8393 init_process_group ("gloo" , rank = rank , world_size = world_size )
8494
95+
8596def cleanup_distributed ():
8697 """Cleanup distributed training"""
8798 destroy_process_group ()
8899
100+
89101def train_client (rank , world_size , data , client_indices , model_state , device ):
90102 """Training function for each client process"""
91103 # Setup distributed environment
92104 setup_distributed (rank , world_size )
93-
105+
94106 # Create model and wrap with DDP
95107 model = DistributedGCN (
96- data .x .size (1 ),
97- HIDDEN_DIM ,
98- int (data .y .max ().item ()) + 1 ,
99- num_layers = 2 ,
100- dropout = DROPOUT_RATE
108+ data .x .size (1 ),
109+ HIDDEN_DIM ,
110+ int (data .y .max ().item ()) + 1 ,
111+ num_layers = 2 ,
112+ dropout = DROPOUT_RATE ,
101113 ).to (device )
102-
103-     model = DDP(model, device_ids=None if device.type == 'cpu' else [device])
114+
115+     model = DDP(model, device_ids=None if device.type == "cpu" else [device])
104116 model .load_state_dict (model_state )
105-
117+
106118 # Create data loader for this client
107119 loader = NeighborLoader (
108120 data ,
109121 input_nodes = client_indices ,
110122 num_neighbors = [10 , 10 ],
111123 batch_size = 512 if len (client_indices ) > 512 else len (client_indices ),
112- shuffle = True
124+ shuffle = True ,
113125 )
114-
126+
115127 optimizer = torch .optim .Adam (model .parameters (), lr = LEARNING_RATE )
116128 model .train ()
117-
129+
118130 # Local training
119131 for epoch in range (LOCAL_STEPS ):
120132 total_loss = 0
121133 for batch in loader :
122134 batch = batch .to (device )
123135 optimizer .zero_grad ()
124136 out = model (batch .x , batch .edge_index )
125-
137+
126138 # Use only the nodes in the current batch that are in training set
127- mask = batch .train_mask [:batch .batch_size ]
139+ mask = batch .train_mask [: batch .batch_size ]
128140 if mask .sum () > 0 :
129- loss = F .cross_entropy (out [:batch .batch_size ][mask ], batch .y [:batch .batch_size ][mask ])
141+ loss = F .cross_entropy (
142+ out [: batch .batch_size ][mask ], batch .y [: batch .batch_size ][mask ]
143+ )
130144 loss .backward ()
131145 optimizer .step ()
132146 total_loss += loss .item ()
133-
147+
134148 cleanup_distributed ()
135149 return model .module .state_dict ()
136150
151+
137152def run_distributed_pyg_experiment (ds , beta ):
138-     device = torch.device('cpu')  # Use CPU for simplicity
139-     ds_obj = Planetoid(root='data/', name=PLANETOID_NAMES[ds])
153+     device = torch.device("cpu")  # Use CPU for simplicity
154+     ds_obj = Planetoid(root="data/", name=PLANETOID_NAMES[ds])
140155 data = ds_obj [0 ].to (device )
141156 in_channels = data .x .size (1 )
142157 num_classes = int (data .y .max ().item ()) + 1
143-
158+
144159 print (f"Running { ds } with β={ beta } " )
145160 print (f"Dataset: { data .num_nodes :,} nodes, { data .edge_index .size (1 ):,} edges" )
146-
161+
147162 # Partition training nodes
148163 train_idx = data .train_mask .nonzero (as_tuple = False ).view (- 1 )
149164 test_idx = data .test_mask .nonzero (as_tuple = False ).view (- 1 )
150-
165+
151166 client_parts = dirichlet_partition (data .y [train_idx ], CLIENT_NUM , beta )
152167 client_idxs = [train_idx [part ] for part in client_parts ]
153-
168+
154169 # Initialize global model
155170 global_model = DistributedGCN (
156- in_channels ,
157- HIDDEN_DIM ,
158- num_classes ,
159- num_layers = 2 ,
160- dropout = DROPOUT_RATE
171+ in_channels , HIDDEN_DIM , num_classes , num_layers = 2 , dropout = DROPOUT_RATE
161172 ).to (device )
162-
173+
163174 t0 = time .time ()
164-
175+
165176 # Federated training loop using simulated distributed training
166177 for round_idx in range (TOTAL_ROUNDS ):
167178 global_state = global_model .state_dict ()
168179 local_states = []
169-
180+
170181 # Simulate distributed training for each client
171182 for client_id in range (CLIENT_NUM ):
172183 # Create client model
173184 client_model = DistributedGCN (
174- in_channels ,
175- HIDDEN_DIM ,
176- num_classes ,
177- num_layers = 2 ,
178- dropout = DROPOUT_RATE
185+ in_channels , HIDDEN_DIM , num_classes , num_layers = 2 , dropout = DROPOUT_RATE
179186 ).to (device )
180-
187+
181188 # Load global state
182189 client_model .load_state_dict (global_state )
183-
190+
184191 # Create client data loader using PyG's NeighborLoader
185192 client_loader = NeighborLoader (
186193 data ,
187194 input_nodes = client_idxs [client_id ],
188195 num_neighbors = [10 , 10 ],
189196 batch_size = min (512 , len (client_idxs [client_id ])),
190- shuffle = True
197+ shuffle = True ,
191198 )
192-
199+
193200 optimizer = torch .optim .Adam (client_model .parameters (), lr = LEARNING_RATE )
194201 client_model .train ()
195-
202+
196203 # Local training
197204 for epoch in range (LOCAL_STEPS ):
198205 for batch in client_loader :
199206 batch = batch .to (device )
200207 optimizer .zero_grad ()
201208 out = client_model (batch .x , batch .edge_index )
202-
209+
203210 # Use only the nodes that are actually in training set
204- local_train_mask = torch .isin (batch .n_id [:batch .batch_size ], client_idxs [client_id ])
211+ local_train_mask = torch .isin (
212+ batch .n_id [: batch .batch_size ], client_idxs [client_id ]
213+ )
205214 if local_train_mask .sum () > 0 :
206215 loss = F .cross_entropy (
207- out [:batch .batch_size ][local_train_mask ],
208- batch .y [:batch .batch_size ][local_train_mask ]
216+ out [: batch .batch_size ][local_train_mask ],
217+ batch .y [: batch .batch_size ][local_train_mask ],
209218 )
210219 loss .backward ()
211220 optimizer .step ()
212-
221+
213222 local_states .append (client_model .state_dict ())
214-
223+
215224 # FedAvg aggregation
216225 global_state = global_model .state_dict ()
217226 for key in global_state .keys ():
218- global_state [key ] = torch .stack ([state [key ].float () for state in local_states ]).mean (0 )
219-
227+ global_state [key ] = torch .stack (
228+ [state [key ].float () for state in local_states ]
229+ ).mean (0 )
230+
220231 global_model .load_state_dict (global_state )
221-
232+
222233 dur = time .time () - t0
223-
234+
224235 # Final evaluation using NeighborLoader for test set
225236 global_model .eval ()
226237 test_loader = NeighborLoader (
227238 data ,
228239 input_nodes = test_idx ,
229240 num_neighbors = [10 , 10 ],
230241 batch_size = min (1024 , len (test_idx )),
231- shuffle = False
242+ shuffle = False ,
232243 )
233-
244+
234245 correct = 0
235246 total = 0
236247 with torch .no_grad ():
237248 for batch in test_loader :
238249 batch = batch .to (device )
239250 out = global_model (batch .x , batch .edge_index )
240- pred = out [:batch .batch_size ].argmax (dim = - 1 )
241- correct += (pred == batch .y [:batch .batch_size ]).sum ().item ()
251+ pred = out [: batch .batch_size ].argmax (dim = - 1 )
252+ correct += (pred == batch .y [: batch .batch_size ]).sum ().item ()
242253 total += batch .batch_size
243-
254+
244255 accuracy = correct / total * 100
245-
256+
246257 # Calculate metrics
247258 total_params = sum (p .numel () for p in global_model .parameters ())
248259 model_size_mb = total_params * 4 / 1024 ** 2
249260 comm_cost = calculate_communication_cost (model_size_mb , TOTAL_ROUNDS , CLIENT_NUM )
250261 mem = peak_memory_mb ()
251-
262+
252263 return {
253-        'accuracy': accuracy,
254-        'total_time': dur,
255-        'computation_time': dur,
256-        'communication_cost_mb': comm_cost,
257-        'peak_memory_mb': mem,
258-        'avg_time_per_round': dur / TOTAL_ROUNDS,
259-        'model_size_mb': model_size_mb,
260-        'total_params': total_params,
261-        'nodes': data.num_nodes,
262-        'edges': data.edge_index.size(1)
264+        "accuracy": accuracy,
265+        "total_time": dur,
266+        "computation_time": dur,
267+        "communication_cost_mb": comm_cost,
268+        "peak_memory_mb": mem,
269+        "avg_time_per_round": dur / TOTAL_ROUNDS,
270+        "model_size_mb": model_size_mb,
271+        "total_params": total_params,
272+        "nodes": data.num_nodes,
273+        "edges": data.edge_index.size(1),
263274 }
264275
276+
265277def main ():
266278 parser = argparse .ArgumentParser ()
267279 parser .add_argument ("--use_cluster" , action = "store_true" )
268280 args = parser .parse_args ()
269281
270-     print("\nDS,IID,BS,Time[s],FinalAcc[%],CompTime[s],CommCost[MB],PeakMem[MB],AvgRoundTime[s],ModelSize[MB],TotalParams")
271-
282+     print(
283+         "\nDS,IID,BS,Time[s],FinalAcc[%],CompTime[s],CommCost[MB],PeakMem[MB],AvgRoundTime[s],ModelSize[MB],TotalParams"
284+     )
285+
272286 for ds in DATASETS :
273287 for beta in IID_BETAS :
274288 try :
@@ -288,5 +302,6 @@ def main():
288302 print (f"Error running { ds } with β={ beta } : { e } " )
289303 print (f"{ ds } ,{ beta } ,-1,0.0,0.00,0.0,0.0,0.0,0.000,0.000,0" )
290304
291- if __name__ == '__main__' :
292- main ()
305+
306+ if __name__ == "__main__" :
307+ main ()