1010from importlib .resources import files
1111from pathlib import Path
1212from typing import Any , List , Optional
13-
1413import attridict
1514import numpy as np
1615import pandas as pd
3332)
3433from fedgraph .utils_nc import get_1hop_feature_sum , save_all_trainers_data
3534
36-
# Optional dependency: low-rank compression support lives in
# fedgraph.low_rank, which may not be implemented/installed in every
# deployment. Record availability in a flag instead of failing at import
# time; run_NC_lowrank checks the flag and raises a descriptive error.
try:
    from .low_rank import Server_LowRank, Trainer_General_LowRank
    LOWRANK_AVAILABLE = True
except ImportError:
    LOWRANK_AVAILABLE = False
3740def run_fedgraph (args : attridict ) -> None :
3841 """
3942 Run the training process for the specified task.
@@ -50,15 +53,26 @@ def run_fedgraph(args: attridict) -> None:
5053 data: Any
5154 Input data for the federated learning task. Format depends on the specific task and
5255 will be explained in more detail below inside specific functions.
53- """
56+ """ # Validate configuration for low-rank compression
57+ if hasattr (args , 'use_lowrank' ) and args .use_lowrank :
58+ if args .fedgraph_task != "NC" :
59+ raise ValueError ("Low-rank compression currently only supported for NC tasks" )
60+ if args .method != "FedAvg" :
61+ raise ValueError ("Low-rank compression currently only supported for FedAvg method" )
62+ if args .use_encryption :
63+ raise ValueError ("Cannot use both encryption and low-rank compression simultaneously" )
64+
65+ # Load data
5466 if args .fedgraph_task != "NC" or not args .use_huggingface :
5567 data = data_loader (args )
5668 else :
57- # use hugging_face instead of use data_loader
58- print ("Using hugging_face for local loading" )
5969 data = None
70+
6071 if args .fedgraph_task == "NC" :
61- run_NC (args , data )
72+ if hasattr (args , 'use_lowrank' ) and args .use_lowrank :
73+ run_NC_lowrank (args , data )
74+ else :
75+ run_NC (args , data )
6276 elif args .fedgraph_task == "GC" :
6377 run_GC (args , data )
6478 elif args .fedgraph_task == "LP" :
@@ -603,7 +617,193 @@ def get_memory_usage(self):
603617 print (f"{ '=' * 80 } \n " )
604618 ray .shutdown ()
605619
def run_NC_lowrank(args: attridict, data: Any = None) -> None:
    """
    Run federated node classification (NC) with low-rank compression of the
    model updates exchanged between server and trainers.

    Mirrors ``run_NC`` but substitutes ``Server_LowRank`` /
    ``Trainer_General_LowRank`` so parameters are communicated in a
    compressed low-rank form.

    Parameters
    ----------
    args : attridict
        Experiment configuration. Low-rank specific settings (all optional):
        ``lowrank_method`` ("fixed" | "adaptive" | "energy"), ``fixed_rank``,
        ``compression_ratio``, ``energy_threshold``.
    data : Any, optional
        Pre-loaded NC data tuple produced by ``data_loader``. Required:
        the huggingface loading path is not supported here (see Raises).

    Raises
    ------
    ImportError
        If the optional ``fedgraph.low_rank`` module is unavailable.
    NotImplementedError
        If ``args.use_huggingface`` is set. Only the locally-loaded ``data``
        tuple defines ``features``, whose dimensionality the server needs —
        the original code reached ``Server_LowRank`` and died with a
        ``NameError`` in that case; we fail fast with a clear message.
    """
    if not LOWRANK_AVAILABLE:
        raise ImportError(
            "Low-rank compression modules not available. Please implement "
            "the low-rank functionality in fedgraph.low_rank"
        )
    if args.use_huggingface:
        # NOTE(review): trainers *could* be built from huggingface data, but
        # the server construction below needs `features.shape[1]`, which only
        # the local-data branch provides. Fail fast until that path exists.
        raise NotImplementedError(
            "run_NC_lowrank currently requires locally loaded data "
            "(args.use_huggingface must be False)"
        )

    print("=== Running NC with Low-Rank Compression ===")
    print(f"Low-rank method: {getattr(args, 'lowrank_method', 'fixed')}")
    if hasattr(args, "lowrank_method"):
        # Echo only the setting relevant to the selected method.
        if args.lowrank_method == "fixed":
            print(f"Fixed rank: {getattr(args, 'fixed_rank', 10)}")
        elif args.lowrank_method == "adaptive":
            print(f"Target compression ratio: {getattr(args, 'compression_ratio', 2.0)}")
        elif args.lowrank_method == "energy":
            print(f"Energy threshold: {getattr(args, 'energy_threshold', 0.95)}")

    monitor = Monitor(use_cluster=args.use_cluster)
    monitor.init_time_start()

    ray.init()
    torch.manual_seed(42)  # fixed seed for reproducible runs

    # num_hops == 0 means no neighbor aggregation across trainers, which is
    # only compatible with plain FedAvg averaging.
    if args.num_hops == 0:
        print("Changing method to FedAvg")
        args.method = "FedAvg"

    # Unpack the locally loaded NC data tuple (same layout as run_NC).
    # `class_num` is recomputed from the trainers' reports further below.
    (
        edge_index,
        features,
        labels,
        idx_train,
        idx_test,
        class_num,
        split_node_indexes,
        communicate_node_global_indexes,
        in_com_train_node_local_indexes,
        in_com_test_node_local_indexes,
        global_edge_indexes_clients,
    ) = data

    if args.saveto_huggingface:
        # Optionally persist each trainer's shard for later reuse.
        save_all_trainers_data(
            split_node_indexes=split_node_indexes,
            communicate_node_global_indexes=communicate_node_global_indexes,
            global_edge_indexes_clients=global_edge_indexes_clients,
            labels=labels,
            features=features,
            in_com_train_node_local_indexes=in_com_train_node_local_indexes,
            in_com_test_node_local_indexes=in_com_test_node_local_indexes,
            n_trainer=args.n_trainer,
            args=args,
        )

    # Hidden dimension follows the convention used by run_NC: small for the
    # classic citation/Reddit benchmarks, larger otherwise.
    if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
        args_hidden = 16
    else:
        args_hidden = 256

    # Per-trainer resource allocation for the Ray actors.
    num_cpus_per_trainer = args.num_cpus_per_trainer
    if args.gpu:
        device = torch.device("cuda")
        num_gpus_per_trainer = args.num_gpus_per_trainer
    else:
        device = torch.device("cpu")
        num_gpus_per_trainer = 0

    @ray.remote(
        num_gpus=num_gpus_per_trainer,
        num_cpus=num_cpus_per_trainer,
        scheduling_strategy="SPREAD",
    )
    class Trainer(Trainer_General_LowRank):  # low-rank-aware trainer actor
        def __init__(self, *args: Any, **kwds: Any):
            super().__init__(*args, **kwds)

    # One remote trainer per data shard.
    trainers = [
        Trainer.remote(
            rank=i,
            args_hidden=args_hidden,
            device=device,
            args=args,
            local_node_index=split_node_indexes[i],
            communicate_node_index=communicate_node_global_indexes[i],
            adj=global_edge_indexes_clients[i],
            train_labels=labels[communicate_node_global_indexes[i]][
                in_com_train_node_local_indexes[i]
            ],
            test_labels=labels[communicate_node_global_indexes[i]][
                in_com_test_node_local_indexes[i]
            ],
            features=features[split_node_indexes[i]],
            idx_train=in_com_train_node_local_indexes[i],
            idx_test=in_com_test_node_local_indexes[i],
        )
        for i in range(args.n_trainer)
    ]

    trainer_information = [ray.get(trainer.get_info.remote()) for trainer in trainers]

    # NOTE(review): "features_num" is summed to obtain the global node count,
    # matching run_NC — presumably each trainer reports its node count here;
    # confirm against Trainer_General_LowRank.get_info.
    global_node_num = sum(info["features_num"] for info in trainer_information)
    # Overwrites the class_num unpacked from `data` with the trainers' view.
    class_num = max(info["label_num"] for info in trainer_information)

    # Weights for the weighted average of per-trainer test metrics.
    test_data_weights = [
        info["len_in_com_test_node_local_indexes"] for info in trainer_information
    ]

    # Initialize each trainer's local model before the first broadcast.
    ray.get(
        [trainer.init_model.remote(global_node_num, class_num) for trainer in trainers]
    )

    server = Server_LowRank(
        features.shape[1], args_hidden, class_num, device, trainers, args
    )
    server.broadcast_params(-1)  # round -1: initial parameter sync
    monitor.init_time_end()

    # No feature-aggregation pretraining is needed for FedAvg, but the phase
    # is still timed so reports stay comparable with run_NC.
    monitor.pretrain_time_start()
    monitor.pretrain_time_end()

    monitor.train_time_start()
    print("Starting federated training with low-rank compression...")

    global_acc_list = []
    for round_idx in range(args.global_rounds):
        server.train(round_idx)

        # Weighted global test accuracy across trainers (column 1 = accuracy).
        futures = [trainer.local_test.remote() for trainer in server.trainers]
        results = np.array([ray.get(future) for future in futures])
        average_test_accuracy = np.average(
            results[:, 1], weights=test_data_weights, axis=0
        )
        global_acc_list.append(average_test_accuracy)
        print(f"Round {round_idx + 1}: Global Test Accuracy = {average_test_accuracy:.4f}")

        # Communication cost: each trainer uploads and downloads one model
        # per round. NOTE(review): this uses the uncompressed model size, so
        # with compression enabled the real traffic is lower — refine once
        # Server_LowRank exposes a compressed-size accessor.
        model_size_mb = server.get_model_size() / (1024 * 1024)
        monitor.add_train_comm_cost(
            upload_mb=model_size_mb * args.n_trainer,
            download_mb=model_size_mb * args.n_trainer,
        )

        # Periodic compression diagnostics, if the server implements them.
        if (round_idx + 1) % 10 == 0 and hasattr(server, "print_compression_stats"):
            server.print_compression_stats()

    monitor.train_time_end()

    # Final evaluation (column 0 = loss, column 1 = accuracy), weighted by
    # each trainer's test-set size.
    futures = [trainer.local_test.remote() for trainer in server.trainers]
    results = np.array([ray.get(future) for future in futures])
    average_final_test_loss = np.average(
        results[:, 0], weights=test_data_weights, axis=0
    )
    average_final_test_accuracy = np.average(
        results[:, 1], weights=test_data_weights, axis=0
    )
    print(f"Final test loss: {average_final_test_loss:.4f}")
    print(f"Final test accuracy: {average_final_test_accuracy:.4f}")

    if hasattr(server, "print_compression_stats"):
        server.print_compression_stats()

    monitor.print_comm_cost()
    ray.shutdown()
806+
607807def run_GC (args : attridict , data : Any ) -> None :
608808 """
609809 Entrance of the training process for graph classification.
0 commit comments