From 8194944b1004b22fe22318d1f3801fed2c323e92 Mon Sep 17 00:00:00 2001
From: Christopher McGale <56483395+chrismcgale@users.noreply.github.com>
Date: Thu, 18 Apr 2024 14:39:39 -0400
Subject: [PATCH 1/2] Reduced get_data_loader calls and locking time

Build every learner's DataLoader once at leader start-up instead of once
per registration, and hold the lock only while mutating learner_list.

---
 leader.py | 74 ++++++++++++++++++++++++++++----------------------------
 utils.py  | 33 +++++++++++++-------------
 2 files changed, 53 insertions(+), 54 deletions(-)

diff --git a/leader.py b/leader.py
index 60f176e..5e4ad27 100644
--- a/leader.py
+++ b/leader.py
@@ -27,43 +27,39 @@ def __init__(self, learner_count):
         self.max_learners = learner_count
         self.learner_list = []
         self.gradient_list = []
+        self.data_loader, self.val_loader, self.num_batches = get_data_loader(max_learners=self.max_learners)
         self.global_model = model
         self.global_optimizer = optimizer_function(self.global_model)
         self.device = device
-        self.num_batches = None
         self.accumulation_count = 0
         logging.info("Leader service initialized")
 
     def RegisterLearner(self, request, context):
-        # lock needed to handle learner_list with thread safety
-        with self.lock:
-            logging.info(f"Registering learner: network_addr={request.network_addr}")
-            if len(self.learner_list) < self.max_learners:
-                # learners < expected, then register them
-                new_id = len(self.learner_list)
-                learner_stub = learner_pb2_grpc.LearnerServiceStub(
-                    grpc.insecure_channel(request.network_addr, options=GRPC_STUB_OPTIONS)
-                )
-                # give it the data loader generator
-                data_loader, num_batches = get_data_loader(learner_id=new_id, max_learners=self.max_learners)
-                # knowing num_batches tells us when training is done; all learners should have an equal # of batches
-                self.num_batches = num_batches
-                self.learner_list.append(
-                    {
-                        'id': new_id,
-                        'network_addr': request.network_addr,
-                        'batches_consumed': 0,
-                        'stub': learner_stub,
-                        'data_loader': data_loader,
-                    }
-                )
-                # if learners = expected, then start training on all of them
-                if len(self.learner_list) == self.max_learners:
-                    thread = threading.Thread(target=self.start_training)
-                    thread.start()
-                return leader_pb2.AckWithMetadata(success=True, message="Registered learner", learner_id=new_id, max_learners=self.max_learners)
-            else:
-                return leader_pb2.Ack(success=False, message="Max learners reached")
+        logging.info(f"Registering learner: network_addr={request.network_addr}")
+        if len(self.learner_list) < self.max_learners:
+            # learners < expected, then register them
+            # building the channel and stub is slow, so keep it outside the lock
+            learner_stub = learner_pb2_grpc.LearnerServiceStub(
+                grpc.insecure_channel(request.network_addr, options=GRPC_STUB_OPTIONS)
+            )
+            # lock only while touching shared registration state
+            with self.lock:
+                new_id = len(self.learner_list)
+                # knowing num_batches tells us when training is done; all learners should have an equal # of batches
+                self.learner_list.append({
+                    'id': new_id,
+                    'network_addr': request.network_addr,
+                    'batches_consumed': 0,
+                    'stub': learner_stub,
+                    'data_loader': self.data_loader[new_id],
+                })
+            # if learners = expected, then start training on all of them
+            if new_id + 1 == self.max_learners:
+                thread = threading.Thread(target=self.start_training)
+                thread.start()
+            return leader_pb2.AckWithMetadata(success=True, message="Registered learner", learner_id=new_id, max_learners=self.max_learners)
+        else:
+            return leader_pb2.Ack(success=False, message="Max learners reached")
 
     def start_training(self):
         time.sleep(3)  # Start up time for the last learner
@@ -101,21 +97,24 @@ def GetData(self, request, context):
         end_index = min(start_index + 10, self.num_batches)
 
         # check if there are batches left to send
-        if start_index < self.num_batches:
-            for i, (input, labels) in enumerate(learner['data_loader']):
-                # send only data batches that fall in the start to end index range
-                if start_index <= i < end_index:
-                    buffer = io.BytesIO()
-                    torch.save((input, labels), buffer)
-                    buffer.seek(0)
-                    serialized_batch = buffer.read()
-                    yield leader_pb2.DataChunk(chunk=serialized_batch)
-                    learner['batches_consumed'] += 1
-
-            logging.info(f"Learner {learner['id']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")
-        else:
+        if start_index >= self.num_batches:
             # when nothing is sent back to learner, the learner gracefully quits itself
             logging.info(f"No more batches to send to learner {learner['id']}.")
+            return
+
+        for i, (input, labels) in enumerate(learner['data_loader']):
+            # send only data batches that fall in the start to end index range
+            if start_index <= i < end_index:
+                buffer = io.BytesIO()
+                torch.save((input, labels), buffer)
+                buffer.seek(0)
+                serialized_batch = buffer.read()
+                yield leader_pb2.DataChunk(chunk=serialized_batch)
+                learner['batches_consumed'] += 1
+
+        logging.info(f"Learner {learner['id']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")
 
 
     def AccumulateGradients(self, request, context):
         # needs to be thread safe for gradient_list to handle accumulations properly
@@ -189,7 +187,7 @@ def run_model_validation(self):
         with torch.no_grad():
             correct = 0
             total = 0
-            for inputs, labels in get_data_loader(valid=True):
+            for inputs, labels in self.val_loader:
                 inputs = inputs.to(device)
                 labels = labels.to(device)
                 outputs = self.global_model(inputs)
diff --git a/utils.py b/utils.py
index 3d1fdfb..f0969d2 100644
--- a/utils.py
+++ b/utils.py
@@ -20,7 +20,7 @@ def load_dataset(data_dir, train=True, transform=None):
     """Load a dataset."""
     return datasets.CIFAR10(root=data_dir, train=train, download=True, transform=transform)
 
-def get_data_loader(data_dir='./data', test=False, valid=False, learner_id=0, max_learners=1):
+def get_data_loader(data_dir='./data', test=False, max_learners=1):
     """Create and return data loaders for training and validation/test, equally divided among learners."""
     transform = get_transform()
     if test:
@@ -42,21 +42,22 @@ def get_data_loader(data_dir='./data', test=False, valid=False, learner_id=0, ma
     # Distribute indices evenly among learners
     per_learner = len(train_idx) // max_learners
     extra_samples = len(train_idx) % max_learners
-    start_idx = learner_id * per_learner + min(learner_id, extra_samples)
-    end_idx = start_idx + per_learner + (1 if learner_id < extra_samples else 0)
-
-    # Calculate the number of batches
-    num_samples_for_learner = end_idx - start_idx
-    num_batches = num_samples_for_learner // 32 + (num_samples_for_learner % 32 > 0)
-
-    # Use Subset to directly slice the dataset without sampling
-    train_subset = Subset(dataset, train_idx[start_idx:end_idx])
+
+    train_loader = []
+
+    for i in range(max_learners):
+        start_idx = i * per_learner + min(i, extra_samples)
+        end_idx = start_idx + per_learner + (1 if i < extra_samples else 0)
+
+        # Calculate the number of batches
+        num_samples_for_learner = end_idx - start_idx
+        num_batches = num_samples_for_learner // 32 + (num_samples_for_learner % 32 > 0)
+
+        # Use Subset to directly slice the dataset without sampling
+        train_subset = Subset(dataset, train_idx[start_idx:end_idx])
+        train_loader.append(DataLoader(train_subset, batch_size=32))
+
     valid_subset = Subset(dataset, valid_idx)
-
-    train_loader = DataLoader(train_subset, batch_size=32)
     valid_loader = DataLoader(valid_subset, batch_size=32)
-    if valid:
-        return valid_loader
-    else:
-        return train_loader, num_batches
+    return train_loader, valid_loader, num_batches
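Aside on the sharding arithmetic above: the loop in get_data_loader gives each learner floor(N / max_learners) samples, and the first N mod max_learners learners absorb one leftover sample each. A minimal, runnable sketch of this shard-once strategy (the toy dataset, the helper name make_shards, and the print loop are illustrative only, not repo code; batch_size=32 mirrors the patch):

    # Illustrative sketch of the shard-once strategy from PATCH 1/2 (not repo code).
    import torch
    from torch.utils.data import DataLoader, Subset, TensorDataset

    def make_shards(dataset, max_learners, batch_size=32):
        """Build one DataLoader per learner, up front, from disjoint index ranges."""
        indices = list(range(len(dataset)))
        per_learner = len(indices) // max_learners
        extra = len(indices) % max_learners
        loaders = []
        for i in range(max_learners):
            # the first `extra` learners each take one leftover sample
            start = i * per_learner + min(i, extra)
            end = start + per_learner + (1 if i < extra else 0)
            loaders.append(DataLoader(Subset(dataset, indices[start:end]), batch_size=batch_size))
        return loaders

    # 100 toy samples across 3 learners -> shards of 34/33/33 samples, 2 batches each
    data = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    for learner_id, loader in enumerate(make_shards(data, max_learners=3)):
        print(f"learner {learner_id}: {len(loader)} batches")

Because the loaders are built once in __init__, RegisterLearner only indexes into the prebuilt list under the lock, which is what shortens the critical section.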
From ea806ea581c5662abd29b0bf7d79ea84f781ac66 Mon Sep 17 00:00:00 2001
From: Christopher McGale <56483395+chrismcgale@users.noreply.github.com>
Date: Thu, 18 Apr 2024 16:06:56 -0400
Subject: [PATCH 2/2] Seamless execution on node failure

Track learners in a dict keyed by network address and add a DropLearner
RPC, so a learner that shuts down deregisters itself instead of stalling
gradient accumulation and model broadcasts.

---
 leader.py                 | 61 +++++++++++++++++++++++++--------------
 learner.py                |  5 ++--
 protos/leader.proto       |  5 ++++
 protos/leader_pb2_grpc.py | 33 ++++++++++++++++++++-
 4 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/leader.py b/leader.py
index 5e4ad27..f485966 100644
--- a/leader.py
+++ b/leader.py
@@ -25,38 +25,46 @@ class LeaderService(leader_pb2_grpc.LeaderServiceServicer):
     def __init__(self, learner_count):
         self.lock = threading.Lock()
         self.max_learners = learner_count
-        self.learner_list = []
+        self.learners = {}
         self.gradient_list = []
         self.data_loader, self.val_loader, self.num_batches = get_data_loader(max_learners=self.max_learners)
         self.global_model = model
         self.global_optimizer = optimizer_function(self.global_model)
         self.device = device
+        self.learner_count = 0
         self.accumulation_count = 0
         logging.info("Leader service initialized")
 
     def RegisterLearner(self, request, context):
         logging.info(f"Registering learner: network_addr={request.network_addr}")
-        if len(self.learner_list) < self.max_learners:
+        if len(self.learners) < self.max_learners:
             # learners < expected, then register them
             # building the channel and stub is slow, so keep it outside the lock
             learner_stub = learner_pb2_grpc.LearnerServiceStub(
                 grpc.insecure_channel(request.network_addr, options=GRPC_STUB_OPTIONS)
             )
             # lock only while touching shared registration state
             with self.lock:
-                new_id = len(self.learner_list)
+                if request.network_addr in self.learners:
+                    return leader_pb2.Ack(success=False, message="Learner address already in use")
+
+                # ids stay monotonic even after drops, so shards are never reassigned
+                new_id = self.learner_count
+                self.learner_count += 1
+
                 # knowing num_batches tells us when training is done; all learners should have an equal # of batches
-                self.learner_list.append({
+                self.learners[request.network_addr] = {
                     'id': new_id,
                     'network_addr': request.network_addr,
                     'batches_consumed': 0,
                     'stub': learner_stub,
                     'data_loader': self.data_loader[new_id],
-                })
-            # if learners = expected, then start training on all of them
-            if new_id + 1 == self.max_learners:
-                thread = threading.Thread(target=self.start_training)
-                thread.start()
+                }
+
+                # if learners = expected, then start training on all of them
+                if len(self.learners) == self.max_learners:
+                    thread = threading.Thread(target=self.start_training)
+                    thread.start()
+
             return leader_pb2.AckWithMetadata(success=True, message="Registered learner", learner_id=new_id, max_learners=self.max_learners)
         else:
             return leader_pb2.Ack(success=False, message="Max learners reached")
 
@@ -64,7 +72,7 @@ def RegisterLearner(self, request, context):
     def start_training(self):
         time.sleep(3)  # Start up time for the last learner
         logging.info("Starting training across all registered learners.")
-        for learner in self.learner_list:
+        for learner in self.learners.values():
             learner['stub'].StartTraining(learner_pb2.Empty())
 
     def GetModel(self, request, context):
@@ -85,7 +93,7 @@ def GetModel(self, request, context):
     def GetData(self, request, context):
         logging.info(f"Sending data to learner network_addr {request.network_addr}")
         # get learner that requested the data
-        learner = next((l for l in self.learner_list if l['network_addr'] == request.network_addr), None)
+        learner = self.learners.get(request.network_addr)
         if not learner:
             context.abort(grpc.StatusCode.NOT_FOUND, 'Learner not found')
             return
@@ -111,7 +119,7 @@ def GetData(self, request, context):
                 yield leader_pb2.DataChunk(chunk=serialized_batch)
                 learner['batches_consumed'] += 1
 
-        logging.info(f"Learner {learner['id']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")
+        logging.info(f"Learner {learner['network_addr']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")
 
 
@@ -123,7 +131,7 @@ def AccumulateGradients(self, request, context):
             gradients = torch.load(buffer)
             self.gradient_list.append(gradients)
 
-            if len(self.gradient_list) == len(self.learner_list):
+            if len(self.gradient_list) == len(self.learners):
                 self.update_and_broadcast_model()
                 self.gradient_list = []
 
@@ -177,7 +185,7 @@ def broadcast_accumulated_gradient(self, model_state: dict):
         thread.start()
 
         # send model state back to learners for sync
-        for learner in self.learner_list:
+        for learner in self.learners.values():
             serialized_model_state = learner_pb2.ModelState(chunk=serialized_model_state_dict)
             learner['stub'].SyncModelState(serialized_model_state)
 
@@ -203,6 +211,15 @@ def run_model_validation(self):
             # program ends
             os._exit(0)
 
+    def DropLearner(self, request, context):
+        with self.lock:
+            network_addr = request.network_addr
+            logging.info(f"Dropping learner: network_addr={network_addr}")
+            # remove the learner so broadcasts and accumulation counts skip it
+            self.learners.pop(network_addr, None)
+        return leader_pb2.Ack(success=True, message="Learner dropped")
+
 
 def serve(learner_count):
     logging.info("Starting leader service...")
diff --git a/learner.py b/learner.py
index 236ae23..60309a6 100644
--- a/learner.py
+++ b/learner.py
@@ -152,7 +152,7 @@ def SyncModelState(self, request, context):
         return learner_pb2.Ack(success=True, message="Model state synchronized successfully")
 
 
-def serve(network_addr, learner_port, leader_stub):
+def serve(network_addr, learner_port, leader_stub, learner_info):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=5), options=GRPC_STUB_OPTIONS)
     learner_service = LearnerService(network_addr, leader_stub)
     learner_pb2_grpc.add_LearnerServiceServicer_to_server(learner_service, server)
@@ -164,6 +164,7 @@ def serve(network_addr, learner_port, leader_stub):
         while True:
             time.sleep(86400)
     except KeyboardInterrupt:
+        leader_stub.DropLearner(learner_info)
         server.stop(0)
 
 def parse_args():
@@ -187,6 +188,6 @@ def parse_args():
     logging.info('Registering learner...')
     is_registered = leader_stub.RegisterLearner(learner_info)
     if is_registered.success:
-        serve(network_addr, learner_port, leader_stub)
+        serve(network_addr, learner_port, leader_stub, learner_info)
     else:
         logging.error('Registering learner unsuccessful')
diff --git a/protos/leader.proto b/protos/leader.proto
index 0b2eb91..8d3f21a 100644
--- a/protos/leader.proto
+++ b/protos/leader.proto
@@ -7,6 +7,7 @@ service LeaderService {
     rpc GetModel(Empty) returns (stream ModelChunk) {};
     rpc GetData(LearnerDataRequest) returns (stream DataChunk) {};
     rpc AccumulateGradients(GradientData) returns (Ack) {};
+    rpc DropLearner(LearnerInfo) returns (Ack) {};
 }
 
 message LearnerInfo {
diff --git a/protos/leader_pb2_grpc.py b/protos/leader_pb2_grpc.py
index 4102152..4112889 100644
--- a/protos/leader_pb2_grpc.py
+++ b/protos/leader_pb2_grpc.py
@@ -34,6 +34,11 @@ def __init__(self, channel):
                 request_serializer=protos_dot_leader__pb2.GradientData.SerializeToString,
                 response_deserializer=protos_dot_leader__pb2.Ack.FromString,
                 )
+        self.DropLearner = channel.unary_unary(
+                '/leader.LeaderService/DropLearner',
+                request_serializer=protos_dot_leader__pb2.LearnerInfo.SerializeToString,
+                response_deserializer=protos_dot_leader__pb2.Ack.FromString,
+                )
 
 
 class LeaderServiceServicer(object):
@@ -62,7 +67,13 @@ def AccumulateGradients(self, request, context):
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def DropLearner(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_LeaderServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -86,6 +95,11 @@ def add_LeaderServiceServicer_to_server(servicer, server):
                     request_deserializer=protos_dot_leader__pb2.GradientData.FromString,
                     response_serializer=protos_dot_leader__pb2.Ack.SerializeToString,
             ),
+            'DropLearner': grpc.unary_unary_rpc_method_handler(
+                    servicer.DropLearner,
+                    request_deserializer=protos_dot_leader__pb2.LearnerInfo.FromString,
+                    response_serializer=protos_dot_leader__pb2.Ack.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'leader.LeaderService', rpc_method_handlers)
@@ -163,3 +177,20 @@ def AccumulateGradients(request,
             protos_dot_leader__pb2.Ack.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def DropLearner(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/leader.LeaderService/DropLearner',
+            protos_dot_leader__pb2.LearnerInfo.SerializeToString,
+            protos_dot_leader__pb2.Ack.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
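To see why the DropLearner path keeps a training round from stalling, here is a minimal, runnable sketch of the registry logic the second patch introduces (plain Python, no gRPC; the LearnerRegistry class and the sample addresses are illustrative, not repo code). A round completes when the gradient count matches the current registry size, so a dropped learner no longer blocks it:

    # Illustrative sketch of the drop-aware registry from PATCH 2/2 (not repo code).
    import threading

    class LearnerRegistry:
        def __init__(self):
            self.lock = threading.Lock()
            self.learners = {}   # network_addr -> metadata, as in leader.py
            self.gradients = []

        def register(self, network_addr):
            with self.lock:
                if network_addr in self.learners:
                    return False  # mirrors "Learner address already in use"
                self.learners[network_addr] = {'id': len(self.learners)}
                return True

        def drop(self, network_addr):
            with self.lock:
                # same idiom as DropLearner: pop with a default, so a repeated
                # drop of an unknown address is harmless
                self.learners.pop(network_addr, None)

        def accumulate(self, grad):
            with self.lock:
                self.gradients.append(grad)
                if len(self.gradients) == len(self.learners):
                    self.gradients = []
                    return True   # round complete: update and broadcast
                return False

    reg = LearnerRegistry()
    for addr in ("10.0.0.1:50051", "10.0.0.2:50051", "10.0.0.3:50051"):
        reg.register(addr)
    reg.drop("10.0.0.2:50051")          # node failure or graceful shutdown
    print(reg.accumulate("grad_a"))     # False: 1 of 2 remaining learners
    print(reg.accumulate("grad_c"))     # True: round completes without the dead node

One caveat worth noting: the comparison is against the registry size at accumulation time, so this scheme assumes a learner that drops mid-round has not already submitted its gradient for that round.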