Skip to content

Commit 69ee808

Browse files
committed
added lowrank to fedavg
1 parent 7ad9c29 commit 69ee808

File tree

6 files changed

+703
-10
lines changed

6 files changed

+703
-10
lines changed

fedgraph/federated_methods.py

Lines changed: 206 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from importlib.resources import files
1111
from pathlib import Path
1212
from typing import Any, List, Optional
13-
1413
import attridict
1514
import numpy as np
1615
import pandas as pd
@@ -33,7 +32,11 @@
3332
)
3433
from fedgraph.utils_nc import get_1hop_feature_sum, save_all_trainers_data
3534

36-
35+
try:
36+
from .low_rank import Server_LowRank, Trainer_General_LowRank
37+
LOWRANK_AVAILABLE = True
38+
except ImportError:
39+
LOWRANK_AVAILABLE = False
3740
def run_fedgraph(args: attridict) -> None:
3841
"""
3942
Run the training process for the specified task.
@@ -50,15 +53,26 @@ def run_fedgraph(args: attridict) -> None:
5053
data: Any
5154
Input data for the federated learning task. Format depends on the specific task and
5255
will be explained in more detail below inside specific functions.
53-
"""
56+
""" # Validate configuration for low-rank compression
57+
if hasattr(args, 'use_lowrank') and args.use_lowrank:
58+
if args.fedgraph_task != "NC":
59+
raise ValueError("Low-rank compression currently only supported for NC tasks")
60+
if args.method != "FedAvg":
61+
raise ValueError("Low-rank compression currently only supported for FedAvg method")
62+
if args.use_encryption:
63+
raise ValueError("Cannot use both encryption and low-rank compression simultaneously")
64+
65+
# Load data
5466
if args.fedgraph_task != "NC" or not args.use_huggingface:
5567
data = data_loader(args)
5668
else:
57-
# use hugging_face instead of use data_loader
58-
print("Using hugging_face for local loading")
5969
data = None
70+
6071
if args.fedgraph_task == "NC":
61-
run_NC(args, data)
72+
if hasattr(args, 'use_lowrank') and args.use_lowrank:
73+
run_NC_lowrank(args, data)
74+
else:
75+
run_NC(args, data)
6276
elif args.fedgraph_task == "GC":
6377
run_GC(args, data)
6478
elif args.fedgraph_task == "LP":
@@ -603,7 +617,193 @@ def get_memory_usage(self):
603617
print(f"{'='*80}\n")
604618
ray.shutdown()
605619

620+
def run_NC_lowrank(args: attridict, data: Any = None) -> None:
    """
    Run federated node-classification (NC) training with low-rank (SVD)
    compression of the communicated model parameters.

    Mirrors the plain FedAvg NC pipeline but swaps in ``Server_LowRank`` and
    ``Trainer_General_LowRank`` so that parameter exchange between server and
    trainers can be compressed.

    Parameters
    ----------
    args : attridict
        Configuration object. Fields read here include ``lowrank_method``
        ('fixed' / 'adaptive' / 'energy'), ``use_cluster``, ``num_hops``,
        ``use_huggingface``, ``saveto_huggingface``, ``dataset``, ``gpu``,
        ``num_cpus_per_trainer``, ``num_gpus_per_trainer``, ``n_trainer``,
        and ``global_rounds``.
    data : Any, optional
        Pre-loaded graph data tuple (see unpacking below). Only used when
        ``args.use_huggingface`` is False; otherwise trainers load their own
        shards.

    Raises
    ------
    ImportError
        If the optional ``fedgraph.low_rank`` package could not be imported.
    """

    # Fail fast if the optional low-rank package was not importable at module load.
    if not LOWRANK_AVAILABLE:
        raise ImportError("Low-rank compression modules not available. Please implement the low-rank functionality in fedgraph.low_rank")

    # Announce the active compression strategy and its key hyperparameter.
    print("=== Running NC with Low-Rank Compression ===")
    print(f"Low-rank method: {getattr(args, 'lowrank_method', 'fixed')}")
    if hasattr(args, 'lowrank_method'):
        if args.lowrank_method == 'fixed':
            print(f"Fixed rank: {getattr(args, 'fixed_rank', 10)}")
        elif args.lowrank_method == 'adaptive':
            print(f"Target compression ratio: {getattr(args, 'compression_ratio', 2.0)}")
        elif args.lowrank_method == 'energy':
            print(f"Energy threshold: {getattr(args, 'energy_threshold', 0.95)}")

    # Timing/communication monitor; initialization phase starts here.
    monitor = Monitor(use_cluster=args.use_cluster)
    monitor.init_time_start()

    ray.init()
    start_time = time.time()  # NOTE(review): never read afterwards — candidate for removal
    torch.manual_seed(42)  # fixed seed for reproducible model initialization

    # With 0-hop communication the method degenerates to plain FedAvg.
    if args.num_hops == 0:
        print("Changing method to FedAvg")
        args.method = "FedAvg"

    if not args.use_huggingface:
        # Unpack the pre-partitioned dataset produced by the data loader.
        # NOTE(review): edge_index, idx_train, idx_test and this class_num are
        # not used below (class_num is recomputed from trainer info) — confirm.
        (
            edge_index, features, labels, idx_train, idx_test, class_num,
            split_node_indexes, communicate_node_global_indexes,
            in_com_train_node_local_indexes, in_com_test_node_local_indexes,
            global_edge_indexes_clients,
        ) = data

        # Optionally persist each trainer's shard for later HuggingFace-based runs.
        if args.saveto_huggingface:
            save_all_trainers_data(
                split_node_indexes=split_node_indexes,
                communicate_node_global_indexes=communicate_node_global_indexes,
                global_edge_indexes_clients=global_edge_indexes_clients,
                labels=labels,
                features=features,
                in_com_train_node_local_indexes=in_com_train_node_local_indexes,
                in_com_test_node_local_indexes=in_com_test_node_local_indexes,
                n_trainer=args.n_trainer,
                args=args,
            )

    # Model configuration: small hidden size for the standard benchmarks,
    # larger otherwise.
    if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
        args_hidden = 16
    else:
        args_hidden = 256

    # Device configuration: per-trainer Ray resource reservation.
    num_cpus_per_trainer = args.num_cpus_per_trainer
    if args.gpu:
        device = torch.device("cuda")
        num_gpus_per_trainer = args.num_gpus_per_trainer
    else:
        device = torch.device("cpu")
        num_gpus_per_trainer = 0


    # Ray actor wrapper around the low-rank trainer; SPREAD scatters actors
    # across cluster nodes.
    @ray.remote(
        num_gpus=num_gpus_per_trainer,
        num_cpus=num_cpus_per_trainer,
        scheduling_strategy="SPREAD",
    )
    class Trainer(Trainer_General_LowRank):  # Use low-rank trainer instead
        def __init__(self, *args: Any, **kwds: Any):
            super().__init__(*args, **kwds)

    # Create trainers: with HuggingFace each actor loads its own shard;
    # otherwise the local data partitions are passed in explicitly.
    if args.use_huggingface:
        trainers = [
            Trainer.remote(
                rank=i, args_hidden=args_hidden, device=device, args=args,
            )
            for i in range(args.n_trainer)
        ]
    else:
        trainers = [
            Trainer.remote(
                rank=i, args_hidden=args_hidden, device=device, args=args,
                local_node_index=split_node_indexes[i],
                communicate_node_index=communicate_node_global_indexes[i],
                adj=global_edge_indexes_clients[i],
                # Labels are indexed in the communicate-node frame, then
                # narrowed to the local train/test subsets.
                train_labels=labels[communicate_node_global_indexes[i]][
                    in_com_train_node_local_indexes[i]
                ],
                test_labels=labels[communicate_node_global_indexes[i]][
                    in_com_test_node_local_indexes[i]
                ],
                features=features[split_node_indexes[i]],
                idx_train=in_com_train_node_local_indexes[i],
                idx_test=in_com_test_node_local_indexes[i],
            )
            for i in range(args.n_trainer)
        ]

    # Get trainer information (feature counts, label counts, split sizes).
    trainer_information = [
        ray.get(trainers[i].get_info.remote()) for i in range(len(trainers))
    ]

    # Aggregate global statistics from the per-trainer reports.
    global_node_num = sum([info["features_num"] for info in trainer_information])
    class_num = max([info["label_num"] for info in trainer_information])

    # Per-trainer sample counts used as weights in the federated averages.
    train_data_weights = [
        info["len_in_com_train_node_local_indexes"] for info in trainer_information
    ]
    test_data_weights = [
        info["len_in_com_test_node_local_indexes"] for info in trainer_information
    ]

    # Initialize models on every trainer before the first broadcast.
    ray.get([
        trainers[i].init_model.remote(global_node_num, class_num)
        for i in range(len(trainers))
    ])

    server = Server_LowRank(
        features.shape[1], args_hidden, class_num, device, trainers, args
    )
    # End initialization: push the initial global parameters to all trainers.
    server.broadcast_params(-1)
    monitor.init_time_end()


    # No pretraining phase for low-rank FedAvg; bracketed for consistent logs.
    monitor.pretrain_time_start()

    monitor.pretrain_time_end()


    monitor.train_time_start()
    print("Starting federated training with low-rank compression...")

    global_acc_list = []
    for i in range(args.global_rounds):

        server.train(i)

        # Evaluation: weighted average of per-trainer test accuracy
        # (row[0] = loss, row[1] = accuracy).
        results = [trainer.local_test.remote() for trainer in server.trainers]
        results = np.array([ray.get(result) for result in results])
        average_test_accuracy = np.average(
            [row[1] for row in results], weights=test_data_weights, axis=0
        )
        global_acc_list.append(average_test_accuracy)

        print(f"Round {i+1}: Global Test Accuracy = {average_test_accuracy:.4f}")

        # Communication cost tracking (enhanced with compression-aware sizing)
        # NOTE(review): assumes get_model_size() returns bytes — confirm.
        model_size_mb = server.get_model_size() / (1024 * 1024)
        monitor.add_train_comm_cost(
            upload_mb=model_size_mb * args.n_trainer,
            download_mb=model_size_mb * args.n_trainer,
        )

        # Periodically report compression statistics, if the server supports it.
        if (i + 1) % 10 == 0 and hasattr(server, 'print_compression_stats'):
            server.print_compression_stats()

    monitor.train_time_end()

    # Final evaluation over all trainers.
    results = [trainer.local_test.remote() for trainer in server.trainers]
    results = np.array([ray.get(result) for result in results])

    average_final_test_loss = np.average(
        [row[0] for row in results], weights=test_data_weights, axis=0
    )
    average_final_test_accuracy = np.average(
        [row[1] for row in results], weights=test_data_weights, axis=0
    )

    print(f"Final test loss: {average_final_test_loss:.4f}")
    print(f"Final test accuracy: {average_final_test_accuracy:.4f}")

    # Print final compression statistics
    if hasattr(server, 'print_compression_stats'):
        server.print_compression_stats()

    if monitor is not None:
        monitor.print_comm_cost()

    ray.shutdown()
806+
607807
def run_GC(args: attridict, data: Any) -> None:
608808
"""
609809
Entrance of the training process for graph classification.

fedgraph/low_rank/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from .compression_utils import (
2+
svd_compress,
3+
svd_decompress,
4+
calculate_compression_ratio,
5+
auto_select_rank
6+
)
7+
from .server_lowrank import Server_LowRank
8+
from .trainer_lowrank import Trainer_General_LowRank
9+
10+
__all__ = [
11+
'svd_compress',
12+
'svd_decompress',
13+
'calculate_compression_ratio',
14+
'auto_select_rank',
15+
'Server_LowRank',
16+
'Trainer_General_LowRank'
17+
]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import torch
2+
import numpy as np
3+
from typing import Dict, List, Tuple, Optional, Any
4+
5+
def svd_compress(tensor: torch.Tensor, rank: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compress a tensor using truncated SVD decomposition.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to compress (2D)
    rank : int
        Target rank for compression. Values outside [1, min(tensor.shape)]
        are clamped into that range.

    Returns
    -------
    U, S, V : Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        SVD components with reduced rank: U is (m, r), S is (r,), V is
        (n, r), so the rank-r approximation is ``U @ diag(S) @ V.T``.

    Raises
    ------
    ValueError
        If ``tensor`` is not 2-dimensional.
    """
    if tensor.dim() != 2:
        raise ValueError("SVD compression only supports 2D tensors")

    # torch.linalg.svd replaces the deprecated torch.svd. The thin SVD
    # (full_matrices=False) is sufficient for truncation and cheaper than
    # the full decomposition. Note it returns Vh (= V^T), not V.
    U, S, Vh = torch.linalg.svd(tensor, full_matrices=False)

    # Clamp to a valid rank: at least 1 (a negative/zero rank would
    # silently slice away columns), at most the number of singular values.
    rank = max(1, min(rank, min(tensor.shape), len(S)))
    U_compressed = U[:, :rank]
    S_compressed = S[:rank]
    # Convert Vh rows back to the V-columns convention used by
    # svd_decompress (which reconstructs via V.t()).
    V_compressed = Vh[:rank, :].t()

    return U_compressed, S_compressed, V_compressed
34+
35+
def svd_decompress(U: torch.Tensor, S: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Reconstruct tensor from SVD components.

    Parameters
    ----------
    U, S, V : torch.Tensor
        SVD components: U is (m, r), S is (r,), V is (n, r).

    Returns
    -------
    torch.Tensor
        Reconstructed (m, n) tensor, ``U @ diag(S) @ V.T``.
    """
    # Scale U's columns by the singular values instead of materializing a
    # dense diag(S): (U * S) broadcasts S over rows, avoiding an r x r
    # temporary and one matrix multiply.
    return (U * S) @ V.t()
50+
51+
def calculate_compression_ratio(original_shape: Tuple[int, int], rank: int) -> float:
    """
    Calculate compression ratio for given rank.

    The rank-r factorization stores ``r * (m + n + 1)`` values
    (U: m*r, V: n*r, S: r), versus ``m * n`` for the dense tensor.

    Parameters
    ----------
    original_shape : Tuple[int, int]
        Shape (m, n) of original tensor
    rank : int
        Compression rank (must be >= 1)

    Returns
    -------
    float
        Compression ratio (original size / compressed size); values below
        1.0 mean the factorization is larger than the dense tensor.

    Raises
    ------
    ValueError
        If ``rank`` is less than 1.
    """
    # A non-positive rank would divide by zero or yield a meaningless
    # negative ratio — reject it explicitly.
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    m, n = original_shape
    original_size = m * n
    compressed_size = rank * (m + n + 1)  # U + S + V
    return original_size / compressed_size
71+
72+
def auto_select_rank(tensor: torch.Tensor, compression_ratio: float = 2.0,
                     energy_threshold: float = 0.95) -> int:
    """
    Automatically select rank based on compression ratio or energy preservation.

    Computes a candidate rank from each criterion and returns the smaller
    (more aggressive compression) of the two.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor (2D)
    compression_ratio : float
        Desired compression ratio
    energy_threshold : float
        Fraction of spectral energy (sum of squared singular values) to preserve

    Returns
    -------
    int
        Selected rank, in [1, min(tensor.shape)]
    """
    m, n = tensor.shape
    max_rank = min(m, n)

    # Method 1: rank achieving the requested compression ratio. A rank-r
    # factorization stores r * (m + n + 1) values, matching the storage
    # model in calculate_compression_ratio, so solve
    # r * (m + n + 1) = (m * n) / compression_ratio for r.
    target_size = (m * n) / compression_ratio
    rank_from_ratio = int(target_size / (m + n + 1))
    rank_from_ratio = max(1, min(rank_from_ratio, max_rank))

    # Method 2: smallest rank preserving the requested energy fraction.
    # svdvals computes only the singular values (cheaper than a full SVD,
    # and torch.svd is deprecated).
    S = torch.linalg.svdvals(tensor)
    total_energy = torch.sum(S ** 2)
    if total_energy == 0:
        # All-zero tensor: every rank preserves "all" the energy, so only
        # the ratio criterion is meaningful.
        return rank_from_ratio
    cumulative_energy = torch.cumsum(S ** 2, dim=0)
    energy_ratios = cumulative_energy / total_energy

    rank_from_energy = int(torch.sum(energy_ratios < energy_threshold).item()) + 1
    rank_from_energy = min(rank_from_energy, max_rank)

    # Use the more conservative (smaller) rank
    return min(rank_from_ratio, rank_from_energy)

0 commit comments

Comments
 (0)