107 changes: 61 additions & 46 deletions leader.py
@@ -25,50 +25,54 @@ class LeaderService(leader_pb2_grpc.LeaderServiceServicer):
def __init__(self, learner_count):
self.lock = threading.Lock()
self.max_learners = learner_count
self.learner_list = []
self.learners = {}
self.gradient_list = []
self.data_loader, self.val_loader, self.num_batches = get_data_loader(max_learners=self.max_learners)
self.global_model = model
self.global_optimizer = optimizer_function(self.global_model)
self.device = device
self.num_batches = None
self.learner_count = 0
self.accumulation_count = 0
logging.info("Leader service initialized")

def RegisterLearner(self, request, context):
        # lock needed so learner_list is mutated thread-safely
with self.lock:
logging.info(f"Registering learner: network_addr={request.network_addr}")
if len(self.learner_list) < self.max_learners:
                # fewer learners than expected, so register this one
new_id = len(self.learner_list)
learner_stub = learner_pb2_grpc.LearnerServiceStub(
grpc.insecure_channel(request.network_addr, options=GRPC_STUB_OPTIONS)
)
# give it the data loader generator
data_loader, num_batches = get_data_loader(learner_id=new_id, max_learners=self.max_learners)
                # knowing num_batches tells the leader when training is done; all learners should have an equal number of batches
self.num_batches = num_batches
self.learner_list.append(
{
'id': new_id,
'network_addr': request.network_addr,
'batches_consumed': 0,
'stub': learner_stub,
'data_loader': data_loader,
}
)
        logging.info(f"Registering learner: network_addr={request.network_addr}")
        if len(self.learners) < self.max_learners:
            # fewer learners than expected, so register this one
            learner_stub = learner_pb2_grpc.LearnerServiceStub(
                grpc.insecure_channel(request.network_addr, options=GRPC_STUB_OPTIONS)
            )

            with self.lock:
                # reject duplicate addresses before consuming a learner id
                if request.network_addr in self.learners:
                    return leader_pb2.Ack(success=False, message="Learner address already in use")

                new_id = self.learner_count
                self.learner_count += 1

                # knowing num_batches tells the leader when training is done;
                # all learners should receive an equal number of batches
                self.learners[request.network_addr] = {
                    'id': new_id,
                    'batches_consumed': 0,
                    'stub': learner_stub,
                    'network_addr': request.network_addr,
                    'data_loader': self.data_loader[new_id],
                }

with self.lock:
# if learners = expected, then start training on all of them
if len(self.learner_list) == self.max_learners:
            if len(self.learners) == self.max_learners:
thread = threading.Thread(target=self.start_training)
thread.start()
return leader_pb2.AckWithMetadata(success=True, message="Registered learner", learner_id=new_id, max_learners=self.max_learners)
else:
return leader_pb2.Ack(success=False, message="Max learners reached")

return leader_pb2.AckWithMetadata(success=True, message="Registered learner", learner_id=new_id, max_learners=self.max_learners)
else:
return leader_pb2.Ack(success=False, message="Max learners reached")

def start_training(self):
time.sleep(3) # Start up time for the last learner
logging.info("Starting training across all registered learners.")
for learner in self.learner_list:
for learner in self.learners.values():
learner['stub'].StartTraining(learner_pb2.Empty())

def GetModel(self, request, context):
@@ -89,7 +93,7 @@ def GetModel(self, request, context):
def GetData(self, request, context):
logging.info(f"Sending data to learner network_addr {request.network_addr}")
# get learner that requested the data
learner = next((l for l in self.learner_list if l['network_addr'] == request.network_addr), None)
        learner = self.learners.get(request.network_addr)
if not learner:
context.abort(grpc.StatusCode.NOT_FOUND, 'Learner not found')
return
@@ -101,21 +105,23 @@ def GetData(self, request, context):
end_index = min(start_index + 10, self.num_batches)

# check if there are batches left to send
if start_index < self.num_batches:
for i, (input, labels) in enumerate(learner['data_loader']):
# send only data batches that fall in the start to end index range
if start_index <= i < end_index:
buffer = io.BytesIO()
torch.save((input, labels), buffer)
buffer.seek(0)
serialized_batch = buffer.read()
yield leader_pb2.DataChunk(chunk=serialized_batch)
learner['batches_consumed'] += 1

logging.info(f"Learner {learner['id']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")
else:
        if start_index >= self.num_batches:
            # nothing left to send; with an empty stream the learner
            # shuts itself down gracefully
            logging.info(f"No more batches to send to learner {learner['id']}.")
            return

for i, (input, labels) in enumerate(learner['data_loader']):
# send only data batches that fall in the start to end index range
if start_index <= i < end_index:
buffer = io.BytesIO()
torch.save((input, labels), buffer)
buffer.seek(0)
serialized_batch = buffer.read()
yield leader_pb2.DataChunk(chunk=serialized_batch)
learner['batches_consumed'] += 1

logging.info(f"Learner {learner['network_addr']} is getting batches {start_index + 1} to {end_index} of {self.num_batches}")



def AccumulateGradients(self, request, context):
        # gradient_list must be mutated under the lock so accumulation stays consistent
@@ -125,7 +131,7 @@ def AccumulateGradients(self, request, context):
gradients = torch.load(buffer)
self.gradient_list.append(gradients)

if len(self.gradient_list) == len(self.learner_list):
        if len(self.gradient_list) == len(self.learners):
self.update_and_broadcast_model()
self.gradient_list = []

@@ -179,7 +185,7 @@ def broadcast_accumulated_gradient(self, model_state: dict):
thread.start()

# send model state back to learners for sync
for learner in self.learner_list:
for learner in self.learners.values():
serialized_model_state = learner_pb2.ModelState(chunk=serialized_model_state_dict)
learner['stub'].SyncModelState(serialized_model_state)

@@ -189,7 +195,7 @@ def run_model_validation(self):
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in get_data_loader(valid=True):
for inputs, labels in self.val_loader:
inputs = inputs.to(device)
labels = labels.to(device)
outputs = self.global_model(inputs)
@@ -205,6 +211,15 @@ def run_model_validation(self):

# program ends
os._exit(0)

    def DropLearner(self, request, context):
        with self.lock:
            logging.info(f"Dropping learner: network_addr={request.network_addr}")
            # remove the learner; no-op if the address is unknown
            self.learners.pop(request.network_addr, None)
        return leader_pb2.Ack(success=True, message="Learner dropped")


def serve(learner_count):
logging.info("Starting leader service...")
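Aside on the duplicate-address guard in RegisterLearner: self.learners maps network_addr to a metadata dict, so the membership test must run against the keys (addr in self.learners); testing against self.learners.values() compares the address to metadata dicts and never matches. A minimal standalone illustration:

learners = {"10.0.0.1:50052": {"id": 0}}
print("10.0.0.1:50052" in learners)           # True: checks the keys
print("10.0.0.1:50052" in learners.values())  # False: values are dicts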
5 changes: 3 additions & 2 deletions learner.py
@@ -152,7 +152,7 @@ def SyncModelState(self, request, context):

return learner_pb2.Ack(success=True, message="Model state synchronized successfully")

def serve(network_addr, learner_port, leader_stub):
def serve(network_addr, learner_port, leader_stub, learner_info):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=5), options=GRPC_STUB_OPTIONS)
learner_service = LearnerService(network_addr, leader_stub)
learner_pb2_grpc.add_LearnerServiceServicer_to_server(learner_service, server)
@@ -164,6 +164,7 @@ def serve(network_addr, learner_port, leader_stub):
while True:
time.sleep(86400)
except KeyboardInterrupt:
leader_stub.DropLearner(learner_info)
server.stop(0)

def parse_args():
@@ -187,6 +188,6 @@ def parse_args():
logging.info('Registering learner...')
is_registered = leader_stub.RegisterLearner(learner_info)
if is_registered.success:
serve(network_addr, learner_port, leader_stub)
serve(network_addr, learner_port, leader_stub, learner_info)
else:
logging.error('Registering learner unsuccessful')
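For context, a minimal sketch of the learner lifecycle this change completes, mirroring the registration flow above (addresses and module paths are illustrative, not taken from this PR):

import grpc
from protos import leader_pb2, leader_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")
leader_stub = leader_pb2_grpc.LeaderServiceStub(channel)
learner_info = leader_pb2.LearnerInfo(network_addr="localhost:50052")

if leader_stub.RegisterLearner(learner_info).success:
    # serve() blocks; on KeyboardInterrupt it now calls
    # leader_stub.DropLearner(learner_info) before stopping, so the
    # leader forgets this learner instead of waiting on it forever.
    serve("localhost:50052", 50052, leader_stub, learner_info)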
5 changes: 5 additions & 0 deletions protos/leader.proto
@@ -7,6 +7,7 @@ service LeaderService {
rpc GetModel(Empty) returns (stream ModelChunk) {};
rpc GetData(LearnerDataRequest) returns (stream DataChunk) {};
rpc AccumulateGradients(GradientData) returns (Ack) {};
  rpc DropLearner(LearnerInfo) returns (Ack) {};
}

message LearnerInfo {
@@ -41,4 +42,8 @@ message DataChunk {

message GradientData {
bytes chunk = 1;
}

message DropLearner {
string network_addr = 1;
}
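Note that the new DropLearner message is unused as written: the rpc takes a LearnerInfo, and the learner passes its existing learner_info when deregistering. Also, since leader_pb2_grpc.py below is edited by hand, its request/response types can drift from this file; regenerating the stubs keeps them consistent (assuming grpcio-tools is installed):

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. protos/leader.proto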
33 changes: 32 additions & 1 deletion protos/leader_pb2_grpc.py
@@ -34,6 +34,11 @@ def __init__(self, channel):
request_serializer=protos_dot_leader__pb2.GradientData.SerializeToString,
response_deserializer=protos_dot_leader__pb2.Ack.FromString,
)
self.DropLearner = channel.unary_unary(
'/leader.LeaderService/DropLearner',
request_serializer=protos_dot_leader__pb2.LearnerInfo.SerializeToString,
                response_deserializer=protos_dot_leader__pb2.Ack.FromString,
)


class LeaderServiceServicer(object):
@@ -62,7 +67,11 @@ def AccumulateGradients(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def DropLearner(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def add_LeaderServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -86,6 +95,11 @@ def add_LeaderServiceServicer_to_server(servicer, server):
request_deserializer=protos_dot_leader__pb2.GradientData.FromString,
response_serializer=protos_dot_leader__pb2.Ack.SerializeToString,
),
'DropLearner': grpc.unary_unary_rpc_method_handler(
servicer.DropLearner,
                request_deserializer=protos_dot_leader__pb2.LearnerInfo.FromString,
response_serializer=protos_dot_leader__pb2.Ack.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'leader.LeaderService', rpc_method_handlers)
@@ -163,3 +177,20 @@ def AccumulateGradients(request,
protos_dot_leader__pb2.Ack.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def DropLearner(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/leader.LeaderService/DropLearner',
protos_dot_leader__pb2.LearnerInfo.SerializeToString,
            protos_dot_leader__pb2.Ack.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
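A hypothetical client-side call to the new RPC, assuming a leader listening on localhost:50051 (address and module paths are illustrative):

import grpc
from protos import leader_pb2, leader_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = leader_pb2_grpc.LeaderServiceStub(channel)
    ack = stub.DropLearner(leader_pb2.LearnerInfo(network_addr="localhost:50052"))
    print(ack.success, ack.message)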
33 changes: 17 additions & 16 deletions utils.py
@@ -20,7 +20,7 @@ def load_dataset(data_dir, train=True, transform=None):
"""Load a dataset."""
return datasets.CIFAR10(root=data_dir, train=train, download=True, transform=transform)

def get_data_loader(data_dir='./data', test=False, valid=False, learner_id=0, max_learners=1):
def get_data_loader(data_dir='./data', test=False, max_learners=1):
"""Create and return data loaders for training and validation/test, equally divided among learners."""
transform = get_transform()
if test:
@@ -42,21 +42,22 @@ def get_data_loader(data_dir='./data', test=False, valid=False, learner_id=0, ma
# Distribute indices evenly among learners
per_learner = len(train_idx) // max_learners
extra_samples = len(train_idx) % max_learners
start_idx = learner_id * per_learner + min(learner_id, extra_samples)
end_idx = start_idx + per_learner + (1 if learner_id < extra_samples else 0)

# Calculate the number of batches
num_samples_for_learner = end_idx - start_idx
num_batches = num_samples_for_learner // 32 + (num_samples_for_learner % 32 > 0)

# Use Subset to directly slice the dataset without sampling
train_subset = Subset(dataset, train_idx[start_idx:end_idx])

train_loader = []

for i in range(max_learners):
start_idx = i * per_learner + min(i, extra_samples)
end_idx = start_idx + per_learner + (1 if i < extra_samples else 0)

# Calculate the number of batches
num_samples_for_learner = end_idx - start_idx
num_batches = num_samples_for_learner // 32 + (num_samples_for_learner % 32 > 0)

# Use Subset to directly slice the dataset without sampling
train_subset = Subset(dataset, train_idx[start_idx:end_idx])
train_loader.append(DataLoader(train_subset, batch_size=32))

valid_subset = Subset(dataset, valid_idx)

train_loader = DataLoader(train_subset, batch_size=32)
valid_loader = DataLoader(valid_subset, batch_size=32)

if valid:
return valid_loader
else:
return train_loader, num_batches
return train_loader, valid_loader, num_batches
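With the new signature, the leader makes a single get_data_loader call up front and hands each learner its loader by id; a sketch of the intended use (the learner count is illustrative):

train_loaders, valid_loader, num_batches = get_data_loader(max_learners=4)
assert len(train_loaders) == 4
loader_for_learner_0 = train_loaders[0]  # yields (inputs, labels) batches of 32

One caveat: when the training indices do not divide evenly, earlier learners can receive one more batch than later ones, and num_batches keeps the value from the final loop iteration, i.e. the smallest per-learner count.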