Skip to content

Commit 941a5b0

Browse files
committed
fixed formatting issue
1 parent 25be758 commit 941a5b0

30 files changed

Lines changed: 2545 additions & 2280 deletions

benchmark/NC.log

Lines changed: 121 additions & 122 deletions
Large diffs are not rendered by default.
Lines changed: 105 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,27 @@
11
#!/usr/bin/env python3
2-
import warnings, logging
3-
warnings.filterwarnings('ignore')
2+
import logging
3+
import warnings
4+
5+
warnings.filterwarnings("ignore")
46
logging.disable(logging.CRITICAL)
57

6-
import argparse, time, resource, torch, torch.nn.functional as F
7-
from torch_geometric.nn import GCNConv
8-
from torch_geometric.datasets import Planetoid
8+
import argparse
9+
import os
10+
import resource
11+
import time
12+
913
import numpy as np
14+
import torch
15+
import torch.nn.functional as F
16+
from torch.distributed import destroy_process_group, init_process_group
17+
from torch.nn.parallel import DistributedDataParallel as DDP
18+
from torch_geometric.datasets import Planetoid
1019

1120
# Distributed PyG imports
1221
from torch_geometric.loader import NeighborLoader
13-
from torch.distributed import init_process_group, destroy_process_group
14-
from torch.nn.parallel import DistributedDataParallel as DDP
15-
import os
22+
from torch_geometric.nn import GCNConv
1623

17-
DATASETS = ['cora', 'citeseer', 'pubmed']
24+
DATASETS = ["cora", "citeseer", "pubmed"]
1825
IID_BETAS = [10000.0, 100.0, 10.0]
1926
CLIENT_NUM = 10
2027
TOTAL_ROUNDS = 200
@@ -23,20 +30,19 @@
2330
HIDDEN_DIM = 64
2431
DROPOUT_RATE = 0.0
2532

26-
PLANETOID_NAMES = {
27-
'cora': 'Cora',
28-
'citeseer': 'CiteSeer',
29-
'pubmed': 'PubMed'
30-
}
33+
PLANETOID_NAMES = {"cora": "Cora", "citeseer": "CiteSeer", "pubmed": "PubMed"}
34+
3135

3236
def peak_memory_mb():
3337
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
3438
return (usage / 1024**2) if usage > 1024**2 else (usage / 1024)
3539

40+
3641
def calculate_communication_cost(model_size_mb, rounds, clients):
3742
cost_per_round = model_size_mb * clients * 2
3843
return cost_per_round * rounds
3944

45+
4046
def dirichlet_partition(labels, num_clients, alpha):
4147
labels = labels.cpu().numpy()
4248
num_classes = labels.max() + 1
@@ -56,18 +62,21 @@ def dirichlet_partition(labels, num_clients, alpha):
5662

5763
return [torch.tensor(ci, dtype=torch.long) for ci in client_idxs]
5864

65+
5966
class DistributedGCN(torch.nn.Module):
60-
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.0):
67+
def __init__(
68+
self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.0
69+
):
6170
super().__init__()
6271
self.num_layers = num_layers
6372
self.dropout = dropout
64-
73+
6574
self.convs = torch.nn.ModuleList()
6675
self.convs.append(GCNConv(in_channels, hidden_channels))
6776
for _ in range(num_layers - 2):
6877
self.convs.append(GCNConv(hidden_channels, hidden_channels))
6978
self.convs.append(GCNConv(hidden_channels, out_channels))
70-
79+
7180
def forward(self, x, edge_index):
7281
for i, conv in enumerate(self.convs):
7382
x = conv(x, edge_index)
@@ -76,199 +85,204 @@ def forward(self, x, edge_index):
7685
x = F.dropout(x, p=self.dropout, training=self.training)
7786
return x
7887

88+
7989
def setup_distributed(rank, world_size):
8090
"""Initialize distributed training"""
81-
os.environ['MASTER_ADDR'] = 'localhost'
82-
os.environ['MASTER_PORT'] = '12355'
91+
os.environ["MASTER_ADDR"] = "localhost"
92+
os.environ["MASTER_PORT"] = "12355"
8393
init_process_group("gloo", rank=rank, world_size=world_size)
8494

95+
8596
def cleanup_distributed():
8697
"""Cleanup distributed training"""
8798
destroy_process_group()
8899

100+
89101
def train_client(rank, world_size, data, client_indices, model_state, device):
90102
"""Training function for each client process"""
91103
# Setup distributed environment
92104
setup_distributed(rank, world_size)
93-
105+
94106
# Create model and wrap with DDP
95107
model = DistributedGCN(
96-
data.x.size(1),
97-
HIDDEN_DIM,
98-
int(data.y.max().item()) + 1,
99-
num_layers=2,
100-
dropout=DROPOUT_RATE
108+
data.x.size(1),
109+
HIDDEN_DIM,
110+
int(data.y.max().item()) + 1,
111+
num_layers=2,
112+
dropout=DROPOUT_RATE,
101113
).to(device)
102-
103-
model = DDP(model, device_ids=None if device.type == 'cpu' else [device])
114+
115+
model = DDP(model, device_ids=None if device.type == "cpu" else [device])
104116
model.load_state_dict(model_state)
105-
117+
106118
# Create data loader for this client
107119
loader = NeighborLoader(
108120
data,
109121
input_nodes=client_indices,
110122
num_neighbors=[10, 10],
111123
batch_size=512 if len(client_indices) > 512 else len(client_indices),
112-
shuffle=True
124+
shuffle=True,
113125
)
114-
126+
115127
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
116128
model.train()
117-
129+
118130
# Local training
119131
for epoch in range(LOCAL_STEPS):
120132
total_loss = 0
121133
for batch in loader:
122134
batch = batch.to(device)
123135
optimizer.zero_grad()
124136
out = model(batch.x, batch.edge_index)
125-
137+
126138
# Use only the nodes in the current batch that are in training set
127-
mask = batch.train_mask[:batch.batch_size]
139+
mask = batch.train_mask[: batch.batch_size]
128140
if mask.sum() > 0:
129-
loss = F.cross_entropy(out[:batch.batch_size][mask], batch.y[:batch.batch_size][mask])
141+
loss = F.cross_entropy(
142+
out[: batch.batch_size][mask], batch.y[: batch.batch_size][mask]
143+
)
130144
loss.backward()
131145
optimizer.step()
132146
total_loss += loss.item()
133-
147+
134148
cleanup_distributed()
135149
return model.module.state_dict()
136150

151+
137152
def run_distributed_pyg_experiment(ds, beta):
138-
device = torch.device('cpu') # Use CPU for simplicity
139-
ds_obj = Planetoid(root='data/', name=PLANETOID_NAMES[ds])
153+
device = torch.device("cpu") # Use CPU for simplicity
154+
ds_obj = Planetoid(root="data/", name=PLANETOID_NAMES[ds])
140155
data = ds_obj[0].to(device)
141156
in_channels = data.x.size(1)
142157
num_classes = int(data.y.max().item()) + 1
143-
158+
144159
print(f"Running {ds} with β={beta}")
145160
print(f"Dataset: {data.num_nodes:,} nodes, {data.edge_index.size(1):,} edges")
146-
161+
147162
# Partition training nodes
148163
train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
149164
test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)
150-
165+
151166
client_parts = dirichlet_partition(data.y[train_idx], CLIENT_NUM, beta)
152167
client_idxs = [train_idx[part] for part in client_parts]
153-
168+
154169
# Initialize global model
155170
global_model = DistributedGCN(
156-
in_channels,
157-
HIDDEN_DIM,
158-
num_classes,
159-
num_layers=2,
160-
dropout=DROPOUT_RATE
171+
in_channels, HIDDEN_DIM, num_classes, num_layers=2, dropout=DROPOUT_RATE
161172
).to(device)
162-
173+
163174
t0 = time.time()
164-
175+
165176
# Federated training loop using simulated distributed training
166177
for round_idx in range(TOTAL_ROUNDS):
167178
global_state = global_model.state_dict()
168179
local_states = []
169-
180+
170181
# Simulate distributed training for each client
171182
for client_id in range(CLIENT_NUM):
172183
# Create client model
173184
client_model = DistributedGCN(
174-
in_channels,
175-
HIDDEN_DIM,
176-
num_classes,
177-
num_layers=2,
178-
dropout=DROPOUT_RATE
185+
in_channels, HIDDEN_DIM, num_classes, num_layers=2, dropout=DROPOUT_RATE
179186
).to(device)
180-
187+
181188
# Load global state
182189
client_model.load_state_dict(global_state)
183-
190+
184191
# Create client data loader using PyG's NeighborLoader
185192
client_loader = NeighborLoader(
186193
data,
187194
input_nodes=client_idxs[client_id],
188195
num_neighbors=[10, 10],
189196
batch_size=min(512, len(client_idxs[client_id])),
190-
shuffle=True
197+
shuffle=True,
191198
)
192-
199+
193200
optimizer = torch.optim.Adam(client_model.parameters(), lr=LEARNING_RATE)
194201
client_model.train()
195-
202+
196203
# Local training
197204
for epoch in range(LOCAL_STEPS):
198205
for batch in client_loader:
199206
batch = batch.to(device)
200207
optimizer.zero_grad()
201208
out = client_model(batch.x, batch.edge_index)
202-
209+
203210
# Use only the nodes that are actually in training set
204-
local_train_mask = torch.isin(batch.n_id[:batch.batch_size], client_idxs[client_id])
211+
local_train_mask = torch.isin(
212+
batch.n_id[: batch.batch_size], client_idxs[client_id]
213+
)
205214
if local_train_mask.sum() > 0:
206215
loss = F.cross_entropy(
207-
out[:batch.batch_size][local_train_mask],
208-
batch.y[:batch.batch_size][local_train_mask]
216+
out[: batch.batch_size][local_train_mask],
217+
batch.y[: batch.batch_size][local_train_mask],
209218
)
210219
loss.backward()
211220
optimizer.step()
212-
221+
213222
local_states.append(client_model.state_dict())
214-
223+
215224
# FedAvg aggregation
216225
global_state = global_model.state_dict()
217226
for key in global_state.keys():
218-
global_state[key] = torch.stack([state[key].float() for state in local_states]).mean(0)
219-
227+
global_state[key] = torch.stack(
228+
[state[key].float() for state in local_states]
229+
).mean(0)
230+
220231
global_model.load_state_dict(global_state)
221-
232+
222233
dur = time.time() - t0
223-
234+
224235
# Final evaluation using NeighborLoader for test set
225236
global_model.eval()
226237
test_loader = NeighborLoader(
227238
data,
228239
input_nodes=test_idx,
229240
num_neighbors=[10, 10],
230241
batch_size=min(1024, len(test_idx)),
231-
shuffle=False
242+
shuffle=False,
232243
)
233-
244+
234245
correct = 0
235246
total = 0
236247
with torch.no_grad():
237248
for batch in test_loader:
238249
batch = batch.to(device)
239250
out = global_model(batch.x, batch.edge_index)
240-
pred = out[:batch.batch_size].argmax(dim=-1)
241-
correct += (pred == batch.y[:batch.batch_size]).sum().item()
251+
pred = out[: batch.batch_size].argmax(dim=-1)
252+
correct += (pred == batch.y[: batch.batch_size]).sum().item()
242253
total += batch.batch_size
243-
254+
244255
accuracy = correct / total * 100
245-
256+
246257
# Calculate metrics
247258
total_params = sum(p.numel() for p in global_model.parameters())
248259
model_size_mb = total_params * 4 / 1024**2
249260
comm_cost = calculate_communication_cost(model_size_mb, TOTAL_ROUNDS, CLIENT_NUM)
250261
mem = peak_memory_mb()
251-
262+
252263
return {
253-
'accuracy': accuracy,
254-
'total_time': dur,
255-
'computation_time': dur,
256-
'communication_cost_mb': comm_cost,
257-
'peak_memory_mb': mem,
258-
'avg_time_per_round': dur / TOTAL_ROUNDS,
259-
'model_size_mb': model_size_mb,
260-
'total_params': total_params,
261-
'nodes': data.num_nodes,
262-
'edges': data.edge_index.size(1)
264+
"accuracy": accuracy,
265+
"total_time": dur,
266+
"computation_time": dur,
267+
"communication_cost_mb": comm_cost,
268+
"peak_memory_mb": mem,
269+
"avg_time_per_round": dur / TOTAL_ROUNDS,
270+
"model_size_mb": model_size_mb,
271+
"total_params": total_params,
272+
"nodes": data.num_nodes,
273+
"edges": data.edge_index.size(1),
263274
}
264275

276+
265277
def main():
266278
parser = argparse.ArgumentParser()
267279
parser.add_argument("--use_cluster", action="store_true")
268280
args = parser.parse_args()
269281

270-
print("\nDS,IID,BS,Time[s],FinalAcc[%],CompTime[s],CommCost[MB],PeakMem[MB],AvgRoundTime[s],ModelSize[MB],TotalParams")
271-
282+
print(
283+
"\nDS,IID,BS,Time[s],FinalAcc[%],CompTime[s],CommCost[MB],PeakMem[MB],AvgRoundTime[s],ModelSize[MB],TotalParams"
284+
)
285+
272286
for ds in DATASETS:
273287
for beta in IID_BETAS:
274288
try:
@@ -288,5 +302,6 @@ def main():
288302
print(f"Error running {ds} with β={beta}: {e}")
289303
print(f"{ds},{beta},-1,0.0,0.00,0.0,0.0,0.0,0.000,0.000,0")
290304

291-
if __name__ == '__main__':
292-
main()
305+
306+
if __name__ == "__main__":
307+
main()

0 commit comments

Comments (0)