diff --git a/openfl-workspace/flower-app-pytorch/.workspace b/openfl-workspace/flower-app-pytorch/.workspace
new file mode 100644
index 0000000000..3c2c5d08b4
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/.workspace
@@ -0,0 +1,2 @@
+current_plan_name: default
+
diff --git a/openfl-workspace/flower-app-pytorch/README.md b/openfl-workspace/flower-app-pytorch/README.md
new file mode 100644
index 0000000000..2886d8922c
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/README.md
@@ -0,0 +1,280 @@
+# Open(FL)ower
+
+This workspace demonstrates new functionality in OpenFL to interoperate with [Flower](https://flower.ai/). In particular, a user can now use the Flower API to run on OpenFL infrastructure. OpenFL acts as an intermediary between the Flower SuperLink and Flower SuperNode, relaying messages across the network using OpenFL's transport mechanisms.
+
+## Overview
+
+In this repository, you'll notice a directory under `src` called `app-pytorch`. This is essentially a Flower PyTorch app created using Flower's `flwr new` command, modified to run a local federation. `client_app.py` and `server_app.py` dictate what will be run by the client and server respectively. `task.py` defines the logic that will be executed by each app, such as the model definition, train/test tasks, etc. In `server_app.py`, a section titled "Save Model" is added in order to save the `best.pbuf` and `last.pbuf` models from the experiment in your local workspace under `./save`. This uses native OpenFL logic to store the model as a `.pbuf` so that it can later be retrieved by `fx model save` into a native format (limited to `.npz` to remain deep-learning-framework agnostic), but this can be overridden to save the model directly following Flower's recommended method for [saving model checkpoints](https://flower.ai/docs/framework/how-to-save-and-load-model-checkpoints.html).
+
+## Execution Methods
+
+There are two ways to execute this:
+
+1. Automatic shutdown, which spawns a `server-app` in isolation and triggers experiment termination once it shuts down. (Default/Recommended)
+2. Running `SuperLink` and `SuperNode` as [long-lived components](#long-lived-superlink-and-supernode) that will wait indefinitely for new runs. (Limited Functionality)
+
+## Getting Started
+
+### Install OpenFL
+
+Create a virtual environment:
+```sh
+pip install virtualenv
+virtualenv ./venv
+source ./venv/bin/activate
+```
+
+Install OpenFL from source:
+```sh
+git clone https://github.com/securefederatedai/openfl.git
+cd openfl
+pip install -e .
+```
+
+### Create a Workspace
+
+Start by creating a workspace:
+
+```sh
+fx workspace create --template flower-app-pytorch --prefix my_workspace
+cd my_workspace
+```
+
+This will create a workspace in your current working directory called `./my_workspace` and install the Flower app defined in `./app-pytorch`. This is where the experiment takes place.
+
+### Configure the Experiment
+Notice under `./plan`, you will find the familiar OpenFL YAML files to configure the experiment. `cols.yaml` and `data.yaml` will be populated by the collaborators that will run the Flower client app and the respective data shard or directory they will perform their training and testing on.
+`plan.yaml` configures the experiment itself. The Open(FL)ower integration makes a few key changes to the `plan.yaml`:
+
+1. Introduction of a new top-level key (`connector`) to configure a newly introduced component called `Connector`.
+Specifically, the Flower integration uses a `Connector` subclass called `ConnectorFlower`. This component is run by the aggregator and is responsible for initializing the Flower `SuperLink` and connecting to the OpenFL server. The `SuperLink` parameters can be configured using `connector.settings.superlink_params`. If nothing is supplied, it will simply run `flower-superlink --insecure` with the command's default settings as dictated by Flower. It also includes the option to run the `flwr run` command via `connector.settings.flwr_run_params`. If `flwr_run_params` is not provided, the user is expected to run `flwr run` from the aggregator machine to initiate the experiment.
+
+```yaml
+connector:
+  defaults: plan/defaults/connector.yaml
+  template: openfl.component.ConnectorFlower
+  settings:
+    superlink_params:
+      insecure: True
+      serverappio-api-address: 127.0.0.1:9091
+      fleet-api-address: 127.0.0.1:9092
+      exec-api-address: 127.0.0.1:9093
+    flwr_run_params:
+      flwr_app_name: "app-pytorch"
+      federation_name: "local-poc"
+```
+
+2. `ConnectorAssigner`, whose task groups are designed to explicitly run the `start_client_adapter` task (defined by the task runner) for every authorized collaborator.
+
+```yaml
+assigner:
+  defaults: plan/defaults/assigner.yaml
+  template: openfl.component.ConnectorAssigner
+  settings:
+    task_groups:
+      - name: Connector_Flower
+        tasks:
+          - start_client_adapter
+```
+
+3. `FlowerTaskRunner`, which executes the `start_client_adapter` task. This task starts the Flower SuperNode and makes a connection to the OpenFL client. The `FlowerTaskRunner` also has an `auto_shutdown` setting, which defaults to `True`. When set to `True`, the task runner shuts down the SuperNode at the completion of an experiment; otherwise, it runs continuously.
+
+```yaml
+task_runner:
+  defaults: plan/defaults/task_runner.yaml
+  template: openfl.federated.task.runner_flower.FlowerTaskRunner
+  settings:
+    auto_shutdown: True
+```
+
+4. `FlowerDataLoader`, with similar high-level functionality to other OpenFL data loaders.
+
+**IMPORTANT NOTE**: `aggregator.settings.rounds_to_train` is set to 1. __Do not edit this__. The actual number of rounds for the experiment is controlled by Flower logic inside `./app-pytorch/pyproject.toml`. The entirety of the Flower experiment will run in a single OpenFL round; the aggregator round is there only to stop the OpenFL components at the completion of the experiment.
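+
+Because the Flower app owns the round count, you can confirm it by reading the app's own config. Below is a minimal sketch (assuming Python 3.11+ for the standard-library `tomllib` and the workspace layout described above):
+
+```python
+# Sketch: the experiment's true round count lives in the Flower app's
+# pyproject.toml, not in OpenFL's plan.yaml.
+import tomllib
+
+with open("src/app-pytorch/pyproject.toml", "rb") as f:
+    conf = tomllib.load(f)
+
+# Defined under [tool.flwr.app.config]
+print(conf["tool"]["flwr"]["app"]["config"]["num-server-rounds"])
+```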
+
+## Running the Workspace
+Run the workspace as normal (certify the workspace, initialize the plan, register the collaborators, etc.):
+
+```sh
+# Generate a Certificate Signing Request (CSR) for the Aggregator
+fx aggregator generate-cert-request
+
+# The CA signs the aggregator's request, which is now available in the workspace
+fx aggregator certify --silent
+
+# Initialize FL Plan and Model Weights for the Federation
+fx plan initialize
+
+################################
+# Setup Collaborator 1
+################################
+
+# Create a collaborator named "collaborator1" that will use shard "0"
+fx collaborator create -n collaborator1 -d 0
+
+# Generate a CSR for collaborator1
+fx collaborator generate-cert-request -n collaborator1
+
+# The CA signs collaborator1's certificate
+fx collaborator certify -n collaborator1 --silent
+
+################################
+# Setup Collaborator 2
+################################
+
+# Create a collaborator named "collaborator2" that will use shard "1"
+fx collaborator create -n collaborator2 -d 1
+
+# Generate a CSR for collaborator2
+fx collaborator generate-cert-request -n collaborator2
+
+# The CA signs collaborator2's certificate
+fx collaborator certify -n collaborator2 --silent
+
+##############################
+# Start to Run the Federation
+##############################
+
+# Run the Aggregator
+fx aggregator start
+```
+
+This will prepare the workspace and start the OpenFL aggregator, the Flower `SuperLink`, and the Flower `ServerApp`. You should see something like:
+
+```sh
+INFO 🧿 Starting the Aggregator Service.
+.
+.
+.
+INFO : Starting Flower SuperLink
+WARNING : Option `--insecure` was set. Starting insecure HTTP server.
+INFO : Flower Deployment Engine: Starting Exec API on 127.0.0.1:9093
+INFO : Flower ECE: Starting ServerAppIo API (gRPC-rere) on 127.0.0.1:9091
+INFO : Flower ECE: Starting Fleet API (GrpcAdapter) on 127.0.0.1:9092
+.
+.
+.
+INFO : [INIT]
+INFO : Using initial global parameters provided by strategy
+INFO : Starting evaluation of initial global parameters
+INFO : Evaluation returned no results (`None`)
+INFO :
+INFO : [ROUND 1]
+```
+
+### Start Collaborators
+Open 2 additional terminals for collaborators.
+For collaborator 1's terminal, run:
+```sh
+fx collaborator start -n collaborator1
+```
+For collaborator 2's terminal, run:
+```sh
+fx collaborator start -n collaborator2
+```
+This will start the collaborator nodes, the Flower `SuperNode`, and the Flower `ClientApp`, and begin running the Flower experiment. You should see something like:
+
+```sh
+ INFO 🧿 Starting a Collaborator Service.
+.
+.
+.
+INFO : Starting Flower SuperNode
+WARNING : Option `--insecure` was set. Starting insecure HTTP channel to 127.0.0.1:...
+INFO : Starting Flower ClientAppIo gRPC server on 127.0.0.1:...
+INFO :
+INFO : [RUN 297994661073077505, ROUND 1]
+```
+### Completion of the Experiment
+Upon completion of the experiment, on the `aggregator` terminal, the Flower components should print an experiment summary while the `SuperLink` continues to receive requests from the `SuperNode`:
+```sh
+INFO : [SUMMARY]
+INFO : Run finished 3 round(s) in 93.29s
+INFO : History (loss, distributed):
+INFO : round 1: 2.0937052175497555
+INFO : round 2: 1.8027011854633406
+INFO : round 3: 1.6812996898487116
+```
+If `automatic_shutdown` is enabled, this will be shortly followed by the OpenFL `aggregator` receiving "results" from the `collaborator` and subsequently shutting down:
+
+```sh
+INFO Round 0: Collaborators that have completed all tasks: ['collaborator1', 'collaborator2']
+INFO Experiment Completed. Cleaning up...
+INFO Sending signal to collaborator collaborator2 to shutdown...
+INFO Sending signal to collaborator collaborator1 to shutdown...
+INFO [OpenFL Connector] Stopping server process with PID: ...
+INFO : SuperLink terminated gracefully.
+INFO [OpenFL Connector] Server process stopped.
+```
+Upon completion of the experiment, on the `collaborator` terminals, the Flower components should output information about the run:
+
+```sh
+INFO : [RUN ..., ROUND 3]
+INFO : Received: evaluate message
+INFO : Start `flwr-clientapp` process
+INFO : [flwr-clientapp] Pull `ClientAppInputs` for token ...
+INFO : [flwr-clientapp] Push `ClientAppOutputs` for token ...
+```
+
+If `automatic_shutdown` is enabled, this will be shortly followed by the OpenFL `collaborator` shutting down:
+
+```sh
+INFO : SuperNode terminated gracefully.
+INFO SuperNode process terminated.
+INFO Shutting down local gRPC server...
+INFO local gRPC server stopped.
+INFO Waiting for tasks...
+INFO Received shutdown signal. Exiting...
+```
+Congratulations, you have run a Flower experiment through OpenFL's task runner!
+
+## Advanced Usage
+### Long-lived SuperLink and SuperNode
+A user can set `automatic_shutdown: False` in the `Connector` settings of the `plan.yaml`:
+
+```yaml
+connector :
+  defaults : plan/defaults/connector.yaml
+  template : openfl.component.ConnectorFlower
+  settings :
+    automatic_shutdown: False
+```
+
+By doing so, Flower's `ServerApp` and `ClientApp` will still shut down at the completion of the Flower experiment, but the `SuperLink` and `SuperNode` will continue to run. As a result, on the `aggregator` terminal, you will see constant requests coming from the `SuperNode`:
+
+```sh
+INFO : GrpcAdapter.PullTaskIns
+INFO : GrpcAdapter.PullTaskIns
+INFO : GrpcAdapter.PullTaskIns
+```
+You can run another experiment by opening another terminal, navigating to this workspace, and running:
+```sh
+flwr run ./src/app-pytorch
+```
+Once you are done, you can manually shut down OpenFL's `collaborator` and Flower's `SuperNode` with `CTRL+C`. This triggers a task completion by the task runner, which then begins the graceful shutdown process of the OpenFL and Flower components.
+
+### Running in SGX Enclave
+Gramine does not support all Linux system calls. The Flower FAB is built and installed at runtime. During this, `utime()` is called, which is an [unsupported call](https://gramine.readthedocs.io/en/latest/devel/features.html#list-of-system-calls), resulting in errors or unexpected behavior. To work around this, when running in an SGX enclave, we opt to build and install the FAB during initialization and package it alongside the OpenFL workspace.
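+
+Concretely, the initialization-time build can invoke Flower's `build` through the patched entry points shipped under `src/patch/` (shown later in this diff). A minimal sketch of the mechanism, with paths assumed from this workspace:
+
+```python
+# Sketch: build the FAB once at initialization, with all writes confined to
+# FLWR_HOME so the enclave only needs to allow/trust a single directory.
+import os
+from pathlib import Path
+
+os.environ["FLWR_HOME"] = os.path.join(os.getcwd(), "src/.flwr")
+os.makedirs(os.environ["FLWR_HOME"], exist_ok=True)
+
+import src.patch.patch_flwr_build  # side effect: replaces flwr.cli.build.build
+from flwr.cli.build import build   # resolves to the patched function
+
+fab_path, fab_hash = build(app=Path("./src/app-pytorch"))
+print(f"FAB written to {fab_path} (hash {fab_hash})")
+```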
+To make this work, we introduce some patches to Flower's build command. In addition, since secure enclaves have strict read/write permissions, dictated by a set of trusted/allowed files, we also patch Flower's telemetry logic in order to consolidate written file locations.
+
+To apply these patches, simply add `patch: True` to the `Connector` and `Task Runner` settings. For the `Task Runner`, also include the name of the Flower app for building and installation.
+
+```yaml
+connector :
+  defaults : plan/defaults/connector.yaml
+  template : openfl.component.ConnectorFlower
+  settings :
+    superlink_params :
+      insecure : True
+      serverappio-api-address : 127.0.0.1:9091
+      fleet-api-address : 127.0.0.1:9092
+      exec-api-address : 127.0.0.1:9093
+      patch : True
+    flwr_run_params :
+      flwr_app_name : "app-pytorch"
+      federation_name : "local-poc"
+      patch : True
+
+task_runner :
+  defaults : plan/defaults/task_runner.yaml
+  template : openfl.federated.task.runner_flower.FlowerTaskRunner
+  settings :
+    patch : True
+    flwr_app_name : "app-pytorch"
+```
\ No newline at end of file
diff --git a/openfl-workspace/flower-app-pytorch/plan/cols.yaml b/openfl-workspace/flower-app-pytorch/plan/cols.yaml
new file mode 100644
index 0000000000..024e743dcd
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/plan/cols.yaml
@@ -0,0 +1,5 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+collaborators:
+
\ No newline at end of file
diff --git a/openfl-workspace/flower-app-pytorch/plan/data.yaml b/openfl-workspace/flower-app-pytorch/plan/data.yaml
new file mode 100644
index 0000000000..6bca42b213
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/plan/data.yaml
@@ -0,0 +1,2 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
diff --git a/openfl-workspace/flower-app-pytorch/plan/plan.yaml b/openfl-workspace/flower-app-pytorch/plan/plan.yaml
new file mode 100644
index 0000000000..18c7e1dd8b
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/plan/plan.yaml
@@ -0,0 +1,58 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+aggregator :
+  defaults : plan/defaults/aggregator.yaml
+  template : openfl.component.Aggregator
+  settings :
+    rounds_to_train : 1 # DO NOT EDIT.
This is to indicate OpenFL communication rounds + persist_checkpoint : false + write_logs : false + +connector : + defaults : plan/defaults/connector.yaml + template : openfl.component.ConnectorFlower + settings : + superlink_params : + insecure : True + serverappio-api-address : 127.0.0.1:9091 # note [kta-intel]: ServerApp will connect here + fleet-api-address : 127.0.0.1:9092 # note [kta-intel]: local gRPC client will connect here + exec-api-address : 127.0.0.1:9093 # note [kta-intel]: port for server-app toml (for flwr run) + flwr_run_params : + flwr_app_name : "app-pytorch" + federation_name : "local-poc" + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : openfl.federated.data.loader_flower.FlowerDataLoader + settings : + collaborator_count : 2 + +task_runner : + defaults : plan/defaults/task_runner.yaml + template : openfl.federated.task.runner_flower.FlowerTaskRunner + +network : + defaults : plan/defaults/network.yaml + +assigner : + defaults : plan/defaults/assigner.yaml + template : openfl.component.RandomGroupedAssigner + settings : + task_groups : + - name : Connector_Flower + percentage : 1.0 + tasks : + - start_client_adapter + +tasks : + defaults : plan/defaults/tasks_connector.yaml + settings : + connect_to : Flower + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/requirements.txt b/openfl-workspace/flower-app-pytorch/requirements.txt new file mode 100644 index 0000000000..016dbec06b --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/requirements.txt @@ -0,0 +1 @@ +./src/app-pytorch \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/__init__.py b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/__init__.py new file mode 100644 index 0000000000..bb8f979717 --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/__init__.py @@ -0,0 +1 @@ +"""app-pytorch: A Flower / PyTorch app.""" diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/client_app.py b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/client_app.py new file mode 100644 index 0000000000..38f8b01047 --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/client_app.py @@ -0,0 +1,54 @@ +"""app-pytorch: A Flower / PyTorch app.""" + +import torch + +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context +from app_pytorch.task import Net, get_weights, load_data, set_weights, test, train + + +# Define Flower Client and client_fn +class FlowerClient(NumPyClient): + def __init__(self, net, trainloader, valloader, local_epochs): + self.net = net + self.trainloader = trainloader + self.valloader = valloader + self.local_epochs = local_epochs + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.net.to(self.device) + + def fit(self, parameters, config): + set_weights(self.net, parameters) + train_loss = train( + self.net, + self.trainloader, + self.local_epochs, + self.device, + ) + return ( + get_weights(self.net), + len(self.trainloader.dataset), + {"train_loss": train_loss}, + ) + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.valloader, self.device) + return loss, len(self.valloader.dataset), 
{"accuracy": accuracy} + + +def client_fn(context: Context): + # Load model and data + net = Net() + data_path = context.node_config["data-path"] + trainloader, valloader = load_data(data_path) + local_epochs = context.run_config["local-epochs"] + + # Return Client instance + return FlowerClient(net, trainloader, valloader, local_epochs).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn, +) diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/server_app.py b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/server_app.py new file mode 100644 index 0000000000..4610346e0b --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/server_app.py @@ -0,0 +1,104 @@ +"""app-pytorch: A Flower / PyTorch app.""" + +from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedAvg +from app_pytorch.task import Net, get_weights + +from flwr.server.client_proxy import ClientProxy +from flwr.common import FitRes, EvaluateRes, Scalar, Parameters, parameters_to_ndarrays +from typing import Optional, Union, OrderedDict, List, Tuple +import numpy as np +from flwr.server.strategy.aggregate import weighted_loss_avg +from flwr.common.logger import log +from logging import WARNING + +class SaveModelStrategy(FedAvg): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.largest_loss = 1e9 + self.aggregated_ndarrays = None + + def aggregate_fit( + self, + server_round: int, + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: + """Aggregate model weights using weighted average and store checkpoint""" + + # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics + aggregated_parameters, aggregated_metrics = super().aggregate_fit( + server_round, results, failures + ) + + if aggregated_parameters is not None: + # Convert `Parameters` to `list[np.ndarray]` + self.aggregated_ndarrays: list[np.ndarray] = parameters_to_ndarrays( + aggregated_parameters + ) + + np.savez(f"last.npz", *self.aggregated_ndarrays) + + return aggregated_parameters, aggregated_metrics + + def aggregate_evaluate( + self, + server_round: int, + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: + """Aggregate evaluation losses using weighted average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate loss + loss_aggregated = weighted_loss_avg( + [ + (evaluate_res.num_examples, evaluate_res.loss) + for _, evaluate_res in results + ] + ) + + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.evaluate_metrics_aggregation_fn: + eval_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No evaluate_metrics_aggregation_fn provided") + + if loss_aggregated < self.largest_loss: + self.largest_loss = loss_aggregated + np.savez(f"best.npz", *self.aggregated_ndarrays) + + return loss_aggregated, metrics_aggregated + +def server_fn(context: Context): + # Read from config + 
+    num_rounds = context.run_config["num-server-rounds"]
+    fraction_fit = context.run_config["fraction-fit"]
+
+    # Initialize model parameters
+    ndarrays = get_weights(Net())
+    parameters = ndarrays_to_parameters(ndarrays)
+
+    # Define strategy
+    strategy = SaveModelStrategy(
+        # fit_metrics_aggregation_fn=weighted_average,
+        fraction_fit=fraction_fit,
+        fraction_evaluate=1.0,
+        min_available_clients=2,
+        initial_parameters=parameters,
+    )
+    config = ServerConfig(num_rounds=num_rounds)
+
+    return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Create ServerApp
+app = ServerApp(server_fn=server_fn)
diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/task.py b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/task.py
new file mode 100644
index 0000000000..c17e9fa05a
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/app_pytorch/task.py
@@ -0,0 +1,108 @@
+"""app-pytorch: A Flower / PyTorch app."""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.datasets import ImageFolder
+import os
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+def load_partition_data(data_path):
+    train_data_path = os.path.join(data_path, "train")
+    test_data_path = os.path.join(data_path, "test")
+
+    # Use ImageFolder to load images from directories
+    train_data = ImageFolder(root=train_data_path, transform=None)
+    test_data = ImageFolder(root=test_data_path, transform=None)
+
+    return train_data, test_data
+
+def load_data(data_path):
+    """Load partitioned CIFAR-10 data."""
+    train_data, test_data = load_partition_data(data_path)
+
+    # Define PyTorch transforms
+    pytorch_transforms = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+
+    # Apply transforms to the datasets
+    train_data.transform = pytorch_transforms
+    test_data.transform = pytorch_transforms
+
+    # Create DataLoaders
+    trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
+    testloader = DataLoader(test_data, batch_size=32)
+
+    return trainloader, testloader
+
+
+def train(net, trainloader, epochs, device):
+    """Train the model on the training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+    net.train()
+    running_loss = 0.0
+    for _ in range(epochs):
+        for images, labels in trainloader:
+            optimizer.zero_grad()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+
+    avg_trainloss = running_loss / len(trainloader)
+    return avg_trainloss
+
+
+def test(net, testloader, device):
+    """Validate the model on the test set."""
+    net.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for
images, labels in testloader: + images, labels = images.to(device), labels.to(device) + outputs = net(images) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + loss = loss / len(testloader) + return loss, accuracy + + +def get_weights(net): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_weights(net, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml new file mode 100644 index 0000000000..39e88d12a1 --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "app-pytorch" +version = "1.0.0" +description = "" +license = "Apache-2.0" +dependencies = [ + "flwr>=1.15.0", + "flwr-datasets[vision]>=0.5.0", + "torch==2.5.1", + "torchvision==0.20.1", +] + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "openfl-dev" + +[tool.flwr.app.components] +serverapp = "app_pytorch.server_app:app" +clientapp = "app_pytorch.client_app:app" + +[tool.flwr.app.config] +num-server-rounds = 3 +fraction-fit = 0.5 +local-epochs = 1 + +[tool.flwr.federations] +default = "local-poc" + +[tool.flwr.federations.local-poc] +address = "127.0.0.1:9093" # this connects to flower --exec-api-address +insecure = true diff --git a/openfl-workspace/flower-app-pytorch/src/patch/__init__.py b/openfl-workspace/flower-app-pytorch/src/patch/__init__.py new file mode 100644 index 0000000000..d5df5b8668 --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/patch/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/openfl-workspace/flower-app-pytorch/src/patch/flower_superlink_patch.py b/openfl-workspace/flower-app-pytorch/src/patch/flower_superlink_patch.py new file mode 100644 index 0000000000..dd6fd0e01e --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/patch/flower_superlink_patch.py @@ -0,0 +1,11 @@ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +import src.patch.patch_flwr_telemetry + +import re +from flwr.server.app import run_superlink +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(run_superlink()) diff --git a/openfl-workspace/flower-app-pytorch/src/patch/flower_supernode_patch.py b/openfl-workspace/flower-app-pytorch/src/patch/flower_supernode_patch.py new file mode 100644 index 0000000000..9960d48cc0 --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/patch/flower_supernode_patch.py @@ -0,0 +1,11 @@ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +import src.patch.patch_flwr_telemetry + +import re +from flwr.client.supernode.app import run_supernode +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(run_supernode()) \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/src/patch/flwr_run_patch.py 
b/openfl-workspace/flower-app-pytorch/src/patch/flwr_run_patch.py
new file mode 100644
index 0000000000..694169c410
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/src/patch/flwr_run_patch.py
@@ -0,0 +1,11 @@
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+
+import src.patch.patch_flwr_build
+
+import re
+from flwr.cli.app import app
+if __name__ == '__main__':
+    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+    sys.exit(app())
\ No newline at end of file
diff --git a/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_build.py b/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_build.py
new file mode 100644
index 0000000000..ff069cc44e
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_build.py
@@ -0,0 +1,149 @@
+import flwr.cli.build
+from flwr.cli.build import write_to_zip, get_fab_filename
+from typing import Annotated, Optional
+import typer
+from pathlib import Path
+from flwr.cli.utils import is_valid_project_name
+from flwr.cli.config_utils import load_and_validate
+import tempfile
+import zipfile
+from flwr.common.constant import FAB_ALLOWED_EXTENSIONS
+import shutil
+import tomli_w
+import hashlib
+import os
+
+def build(
+    app: Annotated[
+        Optional[Path],
+        typer.Option(help="Path of the Flower App to bundle into a FAB"),
+    ] = None,
+) -> tuple[str, str]:
+    """Build a Flower App into a Flower App Bundle (FAB).
+
+    You can run ``flwr build`` without any arguments to bundle the app located in the
+    current directory. Alternatively, you can specify a path using the ``--app``
+    option to bundle an app located at the provided path. For example:
+
+    ``flwr build --app ./apps/flower-hello-world``.
+    """
+    if app is None:
+        app = Path.cwd()
+
+    app = app.resolve()
+    if not app.is_dir():
+        typer.secho(
+            f"❌ The path {app} is not a valid path to a Flower app.",
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        raise typer.Exit(code=1)
+
+    if not is_valid_project_name(app.name):
+        typer.secho(
+            f"❌ The project name {app.name} is invalid, "
+            "a valid project name must start with a letter, "
+            "and can only contain letters, digits, and hyphens.",
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        raise typer.Exit(code=1)
+
+    conf, errors, warnings = load_and_validate(app / "pyproject.toml")
+    if conf is None:
+        typer.secho(
+            "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
+            + "\n".join([f"- {line}" for line in errors]),
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        raise typer.Exit(code=1)
+
+    if warnings:
+        typer.secho(
+            "Project configuration is missing the following "
+            "recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
+            fg=typer.colors.RED,
+            bold=True,
+        )
+
+    # Load .gitignore rules if present
+    ignore_spec = flwr.cli.build._load_gitignore(app)
+
+    list_file_content = ""
+
+    # Remove the 'federations' field from 'tool.flwr' if it exists
+    if (
+        "tool" in conf
+        and "flwr" in conf["tool"]
+        and "federations" in conf["tool"]["flwr"]
+    ):
+        del conf["tool"]["flwr"]["federations"]
+
+    toml_contents = tomli_w.dumps(conf)
+
+    ### PATCH ###
+    # REASONING: original code writes to /tmp/ by default.
Writing to flwr_home allows us to consolidate written files + # This is useful for running in an SGX enclave with Gramine since we need to strictly control allowed/trusted files + flwr_home = os.getenv("FLWR_HOME") + with tempfile.NamedTemporaryFile(suffix=".zip", dir=flwr_home, delete=False) as temp_file: + ############# + + temp_filename = temp_file.name + + with zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_DEFLATED) as fab_file: + write_to_zip(fab_file, "pyproject.toml", toml_contents) + + # Continue with adding other files + all_files = [ + f + for f in app.rglob("*") + if not ignore_spec.match_file(f) + and f.name != temp_filename + and f.suffix in FAB_ALLOWED_EXTENSIONS + and f.name != "pyproject.toml" # Exclude the original pyproject.toml + ] + ### PATCH ### + # REASONING: order matters for creating a hash. This will force consistent ordering of files + # For SGX, to distribute the FAB pre-experiment, the hash must be consistent on all systems + all_files.sort() + ############# + for file_path in all_files: + # Read the file content manually + with open(file_path, "rb") as f: + file_contents = f.read() + + archive_path = file_path.relative_to(app) + write_to_zip(fab_file, str(archive_path), file_contents) + + # Calculate file info + sha256_hash = hashlib.sha256(file_contents).hexdigest() + file_size_bits = os.path.getsize(file_path) * 8 # size in bits + list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n" + + # Add CONTENT and CONTENT.jwt to the zip file + write_to_zip(fab_file, ".info/CONTENT", list_file_content) + + # Get hash of FAB file + content = Path(temp_filename).read_bytes() + fab_hash = hashlib.sha256(content).hexdigest() + + # Set the name of the zip file + fab_filename = get_fab_filename(conf, fab_hash) + + ### PATCH ### + # REASONING: original code writes to /tmp/ by default. Writing to flwr_home allows us to consolidate written files + # Also, return final_path + final_path = os.path.join(flwr_home, fab_filename) + ############# + + shutil.move(temp_filename, final_path) + + typer.secho( + f"🎊 Successfully built {fab_filename}", fg=typer.colors.GREEN, bold=True + ) + + return final_path, fab_hash + + +flwr.cli.build.build = build diff --git a/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_telemetry.py b/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_telemetry.py new file mode 100644 index 0000000000..687dae43bd --- /dev/null +++ b/openfl-workspace/flower-app-pytorch/src/patch/patch_flwr_telemetry.py @@ -0,0 +1,50 @@ +import flwr.common.telemetry +from pathlib import Path +import os +import uuid + +def _get_source_id() -> str: + """Get existing or new source ID.""" + source_id = "unavailable" + # Check if .flwr in home exists + + ### PATCH ### + # REASONING: consolidate written file locations + if os.getenv("FLWR_HOME"): + flwr_dir = Path(os.getenv("FLWR_HOME")) + ############# + else: + try: + home = flwr.common.telemetry._get_home() + except RuntimeError: + # If the home directory can’t be resolved, RuntimeError is raised. + return source_id + + flwr_dir = home.joinpath(".flwr") + + # Create .flwr directory if it does not exist yet. 
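+    # (When FLWR_HOME is set, this is the single consolidated directory,
+    # which keeps every runtime write inside one Gramine-allowed path.)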
+    try:
+        flwr_dir.mkdir(parents=True, exist_ok=True)
+    except PermissionError:
+        return source_id
+
+    source_file = flwr_dir.joinpath("source")
+
+    # If no source_file exists create one and write it
+    if not source_file.exists():
+        try:
+            source_file.touch(exist_ok=True)
+            source_file.write_text(str(uuid.uuid4()), encoding="utf-8")
+        except PermissionError:
+            return source_id
+
+    source_id = source_file.read_text(encoding="utf-8").strip()
+
+    try:
+        uuid.UUID(source_id)
+    except ValueError:
+        source_id = "invalid"
+
+    return source_id
+
+flwr.common.telemetry._get_source_id = _get_source_id
diff --git a/openfl-workspace/flower-app-pytorch/src/setup_data.py b/openfl-workspace/flower-app-pytorch/src/setup_data.py
new file mode 100644
index 0000000000..f86ccf8c3d
--- /dev/null
+++ b/openfl-workspace/flower-app-pytorch/src/setup_data.py
@@ -0,0 +1,55 @@
+import os
+import sys
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+from PIL import Image
+import numpy as np
+
+def main(num_partitions):
+    # Directory to save the partitions
+    save_dir = "data"
+
+    # Ensure the save directory exists
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Initialize FederatedDataset
+    partitioner = IidPartitioner(num_partitions=num_partitions)
+    fds = FederatedDataset(
+        dataset="uoft-cs/cifar10",
+        partitioners={"train": partitioner},
+    )
+
+    # Function to save partition data
+    def save_partition_data(partition_id, partition_train_test):
+        partition_dir = os.path.join(save_dir, f"{partition_id+1}")
+        os.makedirs(partition_dir, exist_ok=True)
+
+        for split, dataset in partition_train_test.items():
+            split_dir = os.path.join(partition_dir, split)
+            os.makedirs(split_dir, exist_ok=True)
+
+            for idx, example in enumerate(dataset):
+                img_array = np.array(example['img'])
+                label = example['label']
+                label_dir = os.path.join(split_dir, str(label))
+                os.makedirs(label_dir, exist_ok=True)
+
+                img = Image.fromarray(img_array)
+                img_path = os.path.join(label_dir, f"{idx}.png")
+                img.save(img_path)
+
+    # Download, split, and save the dataset
+    for partition_id in range(num_partitions):
+        partition = fds.load_partition(partition_id)
+        partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+        save_partition_data(partition_id, partition_train_test)
+
+    print("Dataset downloaded, split, and saved successfully.")
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        print("Usage: python setup_data.py <num_partitions>")
+        sys.exit(1)
+
+    num_partitions = int(sys.argv[1])
+    main(num_partitions)
\ No newline at end of file
diff --git a/openfl-workspace/workspace/plan/defaults/aggregator.yaml b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
index 9fc0481f29..730f7c573f 100644
--- a/openfl-workspace/workspace/plan/defaults/aggregator.yaml
+++ b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,5 +1,9 @@
 template : openfl.component.Aggregator
 settings :
   db_store_rounds : 2
+  write_logs : true
+  init_state_path : init.pbuf
+  best_state_path : best.pbuf
+  last_state_path : last.pbuf
   persist_checkpoint: True
   persistent_db_path: local_state/tensor.db
diff --git a/openfl-workspace/workspace/plan/defaults/connector.yaml b/openfl-workspace/workspace/plan/defaults/connector.yaml
new file mode 100644
index 0000000000..2b6645d22b
--- /dev/null
+++ b/openfl-workspace/workspace/plan/defaults/connector.yaml
@@ -0,0 +1 @@
+template : openfl.component.Connector
\ No newline at end of file
diff --git a/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml
b/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml new file mode 100644 index 0000000000..999da5a8a6 --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml @@ -0,0 +1,4 @@ +start_client_adapter: + function : start_client_adapter + kwargs : + local_server_port : 0 # local grpc server, 0 to dynamically allocate diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 1224c4c24a..a3fe4eed8c 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -3,6 +3,8 @@ """OpenFL Component Module.""" +from importlib import util + from openfl.component.aggregator.aggregator import Aggregator from openfl.component.aggregator.straggler_handling import ( CutoffTimePolicy, @@ -12,4 +14,9 @@ from openfl.component.assigner.assigner import Assigner from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner +from openfl.component.assigner.connector_assigner import ConnectorAssigner from openfl.component.collaborator.collaborator import Collaborator +from openfl.component.interoperability.connector import Connector + +if util.find_spec("flwr") is not None: + from openfl.component.interoperability.connector_flower import ConnectorFlower \ No newline at end of file diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index c04149d035..9b9108449a 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) +import subprocess class Aggregator: """An Aggregator is the central node in federated learning. @@ -74,6 +75,7 @@ def __init__( best_state_path, last_state_path, assigner, + connector, use_delta_updates=True, straggler_handling_policy: StragglerPolicy = CutoffTimePolicy, rounds_to_train=256, @@ -141,6 +143,7 @@ def __init__( self.authorized_cols = authorized_cols self.uuid = aggregator_uuid self.federation_uuid = federation_uuid + self.connector = connector self.quit_job_sent_to = [] @@ -203,8 +206,12 @@ def __init__( tensor_pipe=self.compression_pipeline, ) else: - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) - self._load_initial_tensors() # keys are TensorKeys + if self.connector: + # The model definition will be handled by the respective framework + self.model = {} + else: + self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) + self._load_initial_tensors() # keys are TensorKeys self.collaborator_tensor_results = {} # {TensorKey: nparray}} self._secure_aggregation_enabled = secure_aggregation @@ -302,6 +309,7 @@ def _recover(self): task_id += 1 return recovered + def _load_initial_tensors(self): """Load all of the tensors required to begin federated learning. @@ -734,8 +742,8 @@ def send_local_task_results( collaborator_name, round_number, task_name, - data_size, - named_tensors, + data_size=None, + named_tensors=None, ): """ RPC called by collaborator. 
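+        `data_size` and `named_tensors` are optional so that connector-driven
+        tasks (such as `start_client_adapter`) can report completion without
+        transmitting tensors.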
@@ -799,6 +807,12 @@ def process_task_results( ) return + if self.is_connector_available(): + # Skip to end of round check + with self.lock: + self._is_collaborator_done(collaborator_name, round_number) + self._end_of_round_with_stragglers_check() + task_key = TaskResultKey(task_name, collaborator_name, round_number) # we mustn't have results already @@ -845,6 +859,48 @@ def process_task_results( self._end_of_round_with_stragglers_check() + def is_connector_available(self): + """ + Check if the OpenFL Connector is available. + + Returns: + bool: True if connector is available, False otherwise. + """ + return self.connector is not None + + def start_connector(self): + """ + Start the OpenFL Connector. + + Raises: + RuntimeError: If OpenFL Connector has not been enabled. + """ + if not self.is_connector_available(): + raise RuntimeError("OpenFL Connector has not been enabled.") + return self.connector.start() + + def stop_connector(self): + """ + Stop the OpenFL Connector. + + Raises: + RuntimeError: If OpenFL Connector has not been enabled. + """ + if not self.is_connector_available(): + raise RuntimeError("OpenFL Connector has not been enabled.") + return self.connector.stop() + + def get_local_grpc_client(self): + """ + Get the local gRPC client for the OpenFL Connector. + + Raises: + RuntimeError: If OpenFL Connector has not been enabled. + """ + if not self.is_connector_available(): + raise RuntimeError("OpenFL Connector has not been enabled.") + return self.connector.get_local_grpc_client() + def _end_of_round_with_stragglers_check(self): """ Checks if the minimum required collaborators have reported their results, @@ -1161,20 +1217,22 @@ def _end_of_round_check(self): if self._end_of_round_check_done[self.round_number]: return + if not self.is_connector_available(): # Compute all validation related metrics - logs = {} - for task_name in self.assigner.get_all_tasks_for_round(self.round_number): - logs.update(self._compute_validation_related_task_metrics(task_name)) + logs = {} + for task_name in self.assigner.get_all_tasks_for_round(self.round_number): + logs.update(self._compute_validation_related_task_metrics(task_name)) - # End of round callbacks. - self.callbacks.on_round_end(self.round_number, logs) + # End of round callbacks. 
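+            # Fires user callbacks with this round's aggregated validation metrics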
+            self.callbacks.on_round_end(self.round_number, logs)
 
         # Once all of the task results have been processed
         self._end_of_round_check_done[self.round_number] = True
 
         # Save the latest model
-        logger.info("Saving round %s model...", self.round_number)
-        self._save_model(self.round_number, self.last_state_path)
+        if not self.is_connector_available():
+            logger.info("Saving round %s model...", self.round_number)
+            self._save_model(self.round_number, self.last_state_path)
 
         self.round_number += 1
 
         # resetting stragglers for task for a new round
diff --git a/openfl/component/assigner/__init__.py b/openfl/component/assigner/__init__.py
index 18adaab240..0d760d4807 100644
--- a/openfl/component/assigner/__init__.py
+++ b/openfl/component/assigner/__init__.py
@@ -6,3 +6,4 @@
 from openfl.component.assigner.assigner import Assigner
 from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
 from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner
+from openfl.component.assigner.connector_assigner import ConnectorAssigner
diff --git a/openfl/component/assigner/connector_assigner.py b/openfl/component/assigner/connector_assigner.py
new file mode 100644
index 0000000000..75d1d7d771
--- /dev/null
+++ b/openfl/component/assigner/connector_assigner.py
@@ -0,0 +1,97 @@
+# Copyright 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""Connector assigner module."""
+
+from openfl.component.assigner.assigner import Assigner
+
+
+class ConnectorAssigner(Assigner):
+    """The task assigner maintains a list of tasks.
+
+    This assigner is designed to facilitate interoperability between federated learning frameworks.
+    The expectation is that the OpenFL collaborator is tasked with running the external framework's API.
+    By default, all collaborators will run the same single task, `start_client_adapter`, which
+    starts the external framework's client and begins relaying gRPC messages.
+
+    Attributes:
+        task_groups (list of object): Task groups to assign.
+    """
+
+    def __init__(self, task_groups=None, **kwargs):
+        """Initializes the ConnectorAssigner.
+
+        Args:
+            task_groups (list of object): Task groups to assign.
+            **kwargs: Additional keyword arguments.
+        """
+        self.task_groups = task_groups
+        super().__init__(**kwargs)
+
+    def define_task_assignments(self):
+        """Define task assignments for each round and collaborator.
+
+        This method uses the assigner function to assign tasks to
+        collaborators for each OpenFL round.
+        """
+        if self.task_groups is None:
+            self.task_groups = [
+                {
+                    "name": "default",
+                    "tasks": ['start_client_adapter'],
+                    "collaborators": self.authorized_cols,
+                }
+            ]
+
+        for group in self.task_groups:
+            if "tasks" not in group or not group["tasks"]:
+                group["tasks"] = ['start_client_adapter']
+            if "collaborators" not in group or not group["collaborators"]:
+                group["collaborators"] = self.authorized_cols
+
+            # Check if any task other than 'start_client_adapter' is present
+            for task in group["tasks"]:
+                if task != 'start_client_adapter':
+                    raise ValueError(
+                        f"Unsupported task '{task}' found. "
+                        "ConnectorAssigner only supports 'start_client_adapter'."
+                    )
+
+        # Start by finding all of the tasks in all specified groups
+        self.all_tasks_in_groups = list(
+            {task for group in self.task_groups for task in group["tasks"]}
+        )
+
+        # Initialize the map of collaborators for a given task on a given round
+        for task in self.all_tasks_in_groups:
+            self.collaborators_for_task[task] = {i: [] for i in range(self.rounds)}
+
+        for group in self.task_groups:
+            group_col_list = group["collaborators"]
+            self.task_group_collaborators[group["name"]] = group_col_list
+            for col in group_col_list:
+                # For now, we assume that collaborators have the same tasks for
+                # every round
+                self.collaborator_tasks[col] = {i: group["tasks"] for i in range(self.rounds)}
+            # Now populate reverse lookup of tasks->group
+            for task in group["tasks"]:
+                for round_ in range(self.rounds):
+                    # This should append the list of collaborators performing
+                    # that task
+                    self.collaborators_for_task[task][round_] += group_col_list
+
+    def get_tasks_for_collaborator(self, collaborator_name, round_number):
+        """Get tasks for a specific collaborator in a specific round.
+
+        Args:
+            collaborator_name (str): Name of the collaborator.
+            round_number (int): Round number.
+
+        Returns:
+            list: List of tasks for the collaborator in the specified round.
+        """
+        return self.collaborator_tasks[collaborator_name][round_number]
+
+    def get_collaborators_for_task(self, task_name, round_number):
+        """Get collaborators for a specific task in a specific round.
+
+        Args:
+            task_name (str): Name of the task.
+            round_number (int): Round number.
+
+        Returns:
+            list: List of collaborators for the task in the specified round.
+        """
+        return self.collaborators_for_task[task_name][round_number]
\ No newline at end of file
diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py
index fb0df77131..8b29385bc9 100644
--- a/openfl/component/collaborator/collaborator.py
+++ b/openfl/component/collaborator/collaborator.py
@@ -15,6 +15,7 @@
 from openfl.protocols import utils
 from openfl.transport.grpc.aggregator_client import AggregatorGRPCClient
 from openfl.utilities import TensorKey
+from openfl.transport.grpc import connector
 
 logger = logging.getLogger(__name__)
 
@@ -205,6 +206,25 @@ def do_task(self, task, round_number) -> dict:
         task_name = task.name
         func_name = self.task_config[task_name]["function"]
         kwargs = self.task_config[task_name]["kwargs"]
+        if func_name == "start_client_adapter":
+            # TODO: Need to determine a more general way to handle this in order
+            # to enable additional tasks to be added to the Connector
+            if hasattr(self.task_runner, func_name):
+                method = getattr(self.task_runner, func_name)
+                if callable(method):
+                    framework = self.task_config['settings']["connect_to"]
+                    LocalGRPCServer = connector.get_local_grpc_server(framework)
+                    local_grpc_server = LocalGRPCServer(self.client, self.collaborator_name)
+                    method(local_grpc_server, **kwargs)
+                    # TODO: better to use self.send_task_results(global_output_tensor_dict, round_number, task_name)
+                    # maybe set global_output_tensor to empty
+                    self.client.send_local_task_results(self.collaborator_name, round_number, task_name)
+                    metrics = {f'{self.collaborator_name}/start_client_adapter': 'Completed'}
+                    return metrics
+                else:
+                    raise AttributeError(f"{func_name} is not callable on {self.task_runner}")
+            else:
+                raise AttributeError(f"{func_name} does not exist on {self.task_runner}")
 
         # this would return a list of what tensors we require as TensorKeys
        required_tensorkeys_relative = self.task_runner.get_required_tensorkeys_for_function(
diff --git a/openfl/component/interoperability/__init__.py b/openfl/component/interoperability/__init__.py
new file mode 100644
index 0000000000..0ea8b1f9df
--- /dev/null
+++ b/openfl/component/interoperability/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from importlib import util
+from openfl.component.interoperability.connector import Connector
+
+if util.find_spec("flwr") is not None:
+    from openfl.component.interoperability.connector_flower import ConnectorFlower
diff --git a/openfl/component/interoperability/connector.py b/openfl/component/interoperability/connector.py
new file mode 100644
index 0000000000..ec371a4387
--- /dev/null
+++ b/openfl/component/interoperability/connector.py
@@ -0,0 +1,49 @@
+import signal
+import sys
+from logging import getLogger
+from abc import ABC, abstractmethod
+
+class Connector(ABC):
+    """
+    Abstract base class for managing a server process of an external federated learning framework
+    and the connection with OpenFL's server.
+    """
+
+    def __init__(self, component_name: str = "Base", **kwargs):
+        """
+        Initialize the Connector.
+
+        Args:
+            component_name (str): The name of the specific Connector component being used.
+        """
+        self.logger = getLogger(__name__)
+        self.component_name = component_name
+        self.local_grpc_client = None
+
+        # Register signal handler for clean termination
+        signal.signal(signal.SIGINT, self._handle_sigint)
+
+    @abstractmethod
+    def start(self):
+        """Start the server process with the provided command."""
+        pass
+
+    @abstractmethod
+    def stop(self):
+        """Stop the server process if it is running."""
+        pass
+
+    def get_local_grpc_client(self):
+        """Get the local gRPC client."""
+        return self.local_grpc_client
+
+    def print_connector_info(self):
+        """Print information indicating which Connector component is being used."""
+        self.logger.info(f"OpenFL Connector Enabled: {self.component_name}")
+
+    def _handle_sigint(self, signum, frame):
+        """Handle the SIGINT signal (Ctrl+C) to cleanly stop the server process and its children."""
+        self.logger.info("[OpenFL Connector] SIGINT received. Terminating server process...")
+        self.stop()
+        sys.exit(0)
\ No newline at end of file
diff --git a/openfl/component/interoperability/connector_flower.py b/openfl/component/interoperability/connector_flower.py
new file mode 100644
index 0000000000..c8a2dcb93d
--- /dev/null
+++ b/openfl/component/interoperability/connector_flower.py
@@ -0,0 +1,201 @@
+import subprocess
+
+import psutil
+
+from openfl.component.interoperability.connector import Connector
+from openfl.transport.grpc.connector.flower.local_grpc_client import LocalGRPCClient
+
+import os
+os.environ["FLWR_HOME"] = os.path.join(os.getcwd(), "src/.flwr")
+os.makedirs(os.environ["FLWR_HOME"], exist_ok=True)
+
+class ConnectorFlower(Connector):
+    """
+    Connector subclass for the Flower framework.
+    Responsible for generating the Flower server command.
+    """
+
+    def __init__(self,
+                 superlink_params: dict,
+                 flwr_run_params: dict = None,
+                 automatic_shutdown: bool = True,
+                 **kwargs):
+        """
+        Initialize ConnectorFlower by building the server command.
+
+        Args:
+            superlink_params (dict): A dictionary of Flower server settings.
+            flwr_run_params (dict): A dictionary containing the Flower run parameters.
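+            automatic_shutdown (bool): If True (default), run the ServerApp as a
+                separate process and shut down the Flower components once the
+                experiment completes.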
+        """
+        super().__init__(component_name="Flower")
+        self._process = None
+
+        self.automatic_shutdown = automatic_shutdown
+        self.signal_shutdown_sent = False
+
+        self.superlink_params = superlink_params
+        self.flwr_superlink_command = self._build_flwr_superlink_command()
+
+        self.flwr_run_params = flwr_run_params
+        self.flwr_run_command = self._build_flwr_run_command() if self.flwr_run_params else None
+
+        self.local_grpc_client = self._get_local_grpc_client()
+
+    def _get_local_grpc_client(self):
+        """
+        Create and return a LocalGRPCClient instance based on superlink_params.
+
+        Returns:
+            LocalGRPCClient: An instance of LocalGRPCClient initialized with the
+                connector (Fleet API) address and the automatic-shutdown flag.
+        """
+        connector_address = self.superlink_params.get("fleet-api-address", "0.0.0.0:9092")
+        return LocalGRPCClient(connector_address, self.automatic_shutdown)
+
+    def _build_flwr_superlink_command(self) -> list[str]:
+        """
+        Build the command to start the Flower SuperLink based on superlink_params.
+
+        Returns:
+            list[str]: A list representing the SuperLink start command.
+        """
+        if self.superlink_params.get("patch"):
+            command = ["python", "src/patch/flower_superlink_patch.py", "--fleet-api-type", "grpc-adapter"]
+        else:
+            command = ["flower-superlink", "--fleet-api-type", "grpc-adapter"]
+
+        if "insecure" in self.superlink_params and self.superlink_params["insecure"]:
+            command += ["--insecure"]
+
+        if "serverappio-api-address" in self.superlink_params:
+            command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])]
+            # flwr default: 0.0.0.0:9091
+
+        if "fleet-api-address" in self.superlink_params:
+            command += ["--fleet-api-address", str(self.superlink_params["fleet-api-address"])]
+            # flwr default: 0.0.0.0:9092
+
+        if "exec-api-address" in self.superlink_params:
+            command += ["--exec-api-address", str(self.superlink_params["exec-api-address"])]
+            # flwr default: 0.0.0.0:9093
+
+        if self.automatic_shutdown:
+            command += ["--isolation", "process"]
+            self.flwr_serverapp_command = self._build_flwr_serverapp_command()
+            # flwr will default to "--isolation subprocess"
+
+        return command
+
+    def _build_flwr_serverapp_command(self) -> list[str]:
+        """
+        Build the command to start the Flower ServerApp based on superlink_params.
+
+        Returns:
+            list[str]: A list representing the ServerApp start command.
+        """
+        command = ["flwr-serverapp", "--run-once"]
+
+        if "insecure" in self.superlink_params and self.superlink_params["insecure"]:
+            command += ["--insecure"]
+
+        if "serverappio-api-address" in self.superlink_params:
+            command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])]
+
+        return command
+
+    def is_flwr_serverapp_running(self):
+        """
+        Check if the flwr_serverapp subprocess is still running.
+
+        Returns:
+            bool: True if the ServerApp is running, False otherwise.
+        """
+        if not hasattr(self, 'flwr_serverapp_subprocess'):
+            self.logger.debug("[OpenFL Connector] ServerApp was never started.")
+            return False
+
+        if self.flwr_serverapp_subprocess.poll() is None:
+            self.logger.debug("[OpenFL Connector] ServerApp is still running.")
+            return True
+
+        if not self.signal_shutdown_sent:
+            self.signal_shutdown_sent = True
+            self.logger.info("[OpenFL Connector] Experiment has ended.
+                             "Sending signal to shut down Flower components.")
+
+        return False
+
+    def _stop_flwr_serverapp(self):
+        """Stop the `flwr_serverapp` subprocess if it is still running."""
+        if hasattr(self, 'flwr_serverapp_subprocess') and self.flwr_serverapp_subprocess.poll() is None:
+            self.logger.debug("[OpenFL Connector] ServerApp still running. Stopping...")
+            self.flwr_serverapp_subprocess.terminate()
+            try:
+                self.flwr_serverapp_subprocess.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.flwr_serverapp_subprocess.kill()
+
+    def _build_flwr_run_command(self) -> list[str]:
+        """
+        Build the `flwr run` command to run the Flower application.
+
+        Returns:
+            list[str]: A list representing the flwr_run command.
+        """
+        federation_name = self.flwr_run_params.get("federation_name")
+        flwr_app_name = self.flwr_run_params.get("flwr_app_name")
+
+        if self.flwr_run_params.get("patch"):
+            command = ["python", "src/patch/flwr_run_patch.py", "run", f"./src/{flwr_app_name}"]
+        else:
+            command = ["flwr", "run", f"./src/{flwr_app_name}"]
+
+        if federation_name:
+            command.append(federation_name)
+
+        return command
+
+    def start(self):
+        """Start the `flower-superlink` and `flwr run` subprocesses with the provided commands."""
+        if self._process is None:
+            self.logger.info(f"[OpenFL Connector] Starting server process: {' '.join(self.flwr_superlink_command)}")
+            self._process = subprocess.Popen(self.flwr_superlink_command)
+            self.logger.info(f"[OpenFL Connector] Server process started with PID: {self._process.pid}")
+        else:
+            self.logger.info("[OpenFL Connector] Server process is already running.")
+
+        if hasattr(self, 'flwr_run_command') and self.flwr_run_command:
+            self.logger.info(f"[OpenFL Connector] Starting `flwr run` subprocess: {' '.join(self.flwr_run_command)}")
+            subprocess.run(self.flwr_run_command)
+
+        if hasattr(self, 'flwr_serverapp_command') and self.flwr_serverapp_command:
+            self.local_grpc_client.set_is_flwr_serverapp_running_callback(self.is_flwr_serverapp_running)
+            self.flwr_serverapp_subprocess = subprocess.Popen(self.flwr_serverapp_command)
+
+    def stop(self):
+        """Stop the `flower-superlink` subprocess."""
+        self._stop_flwr_serverapp()
+        if self._process:
+            try:
+                self.logger.info(f"[OpenFL Connector] Stopping server process with PID: {self._process.pid}...")
+                main_process = psutil.Process(self._process.pid)
+                sub_processes = main_process.children(recursive=True)
+                for sub_process in sub_processes:
+                    self.logger.info(f"[OpenFL Connector] Stopping server subprocess with PID: {sub_process.pid}...")
+                    sub_process.terminate()
+                _, still_alive = psutil.wait_procs(sub_processes, timeout=1)
+                for p in still_alive:
+                    p.kill()
+                try:
+                    self._process.terminate()
+                    self._process.wait(timeout=5)
+                except subprocess.TimeoutExpired:
+                    self._process.kill()
+                self._process = None
+                self.logger.info("[OpenFL Connector] Server process stopped.")
+            except Exception as e:
+                self.logger.debug(f"[OpenFL Connector] Error during graceful shutdown: {e}")
+                self._process.kill()
+                self.logger.info("[OpenFL Connector] Server process forcefully terminated.")
+        else:
+            self.logger.info("[OpenFL Connector] No server process is currently running.")
\ No newline at end of file
diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py
index 72aaac4aea..3afb2f3e4e 100644
--- a/openfl/federated/__init__.py
+++ b/openfl/federated/__init__.py
@@ -21,6 +21,10 @@
 if util.find_spec("xgboost") is not None:
     from openfl.federated.data import XGBoostDataLoader
     from openfl.federated.task import XGBoostTaskRunner
+if util.find_spec("flwr") is not None:
util.find_spec("flwr") is not None: + from openfl.federated.data import FlowerDataLoader + from openfl.federated.task import FlowerTaskRunner + __all__ = [ "Plan", diff --git a/openfl/federated/data/__init__.py b/openfl/federated/data/__init__.py index 53e56a7f7d..29667f7b23 100644 --- a/openfl/federated/data/__init__.py +++ b/openfl/federated/data/__init__.py @@ -16,3 +16,6 @@ if util.find_spec("xgboost") is not None: from openfl.federated.data.loader_xgb import XGBoostDataLoader # NOQA + +if util.find_spec("flwr") is not None: + from openfl.federated.data.loader_flower import FlowerDataLoader # NOQA diff --git a/openfl/federated/data/loader_flower.py b/openfl/federated/data/loader_flower.py new file mode 100644 index 0000000000..3c91c04d6d --- /dev/null +++ b/openfl/federated/data/loader_flower.py @@ -0,0 +1,52 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""FlowerDataLoader module.""" + +from openfl.federated.data.loader import DataLoader + + +class FlowerDataLoader(DataLoader): + """Flower Dataloader + + This class extends the OpenFL DataLoader to provide functionality for + loading and partitioning data for a Flower workload. + + Attributes: + data_shard (int): The shard number of the dataset. + num_partitions (int): The number of partitions to divide the dataset into. + """ + + def __init__(self, data_path, **kwargs): + """ + Initialize the FlowerDataLoader. + + Args: + data_path (str or int): The directory of the dataset. + collaborator_count (int): The number of partitions to divide the dataset into. + **kwargs: Additional keyword arguments to pass to the parent DataLoader class. + + Raises: + ValueError: If collaborator_count is not provided or if data_path is not a number. + """ + super().__init__(**kwargs) + self.data_path = data_path + + def get_node_configs(self): + """ + Get the configuration for each node. + + This method returns the number of partitions and the data shard, + which can be used by each node to access the dataset. + + Returns: + tuple: A tuple containing the number of partitions and the data shard. + """ + return self.data_path + + def get_feature_shape(self): + """ + Override the parent method to return None. + Flower's own infrastructure will handle the feature shape. 
+ """ + return None \ No newline at end of file diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 62ea68160e..58f9e28b7d 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -273,6 +273,7 @@ def __init__(self): self.collaborator_ = None # collaborator object self.aggregator_ = None # aggregator object self.assigner_ = None # assigner object + self.connector_ = None # OpenFL Connector object self.loader_ = None # data loader object self.runner_ = None # task runner object @@ -329,6 +330,16 @@ def get_assigner(self): self.assigner_ = Plan.build(**defaults) return self.assigner_ + + def get_connector(self): + """Get OpenFL Connector object.""" + defaults = self.config.get("connector") + logger.info("Connector defaults: %s", defaults) + + if self.connector_ is None and defaults: + self.connector_ = Plan.build(**defaults) + + return self.connector_ def get_tasks(self): """Get federation tasks.""" @@ -381,6 +392,7 @@ def get_aggregator(self, tensor_dict=None): defaults[SETTINGS]["assigner"] = self.get_assigner() defaults[SETTINGS]["compression_pipeline"] = self.get_tensor_pipe() defaults[SETTINGS]["straggler_handling_policy"] = self.get_straggler_handling_policy() + defaults[SETTINGS]["connector"] = self.get_connector() # TODO: Load callbacks from plan. @@ -456,9 +468,7 @@ def get_task_runner(self, data_loader): if self.runner_ is None: self.runner_ = Plan.build(**defaults) - # Define task dependencies after taskrunner has been initialized self.runner_.initialize_tensorkeys_for_functions() - return self.runner_ def get_collaborator( diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index 7d1d7dfaeb..1763b3c54d 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -14,3 +14,5 @@ from openfl.federated.task.runner_pt import PyTorchTaskRunner # NOQA if util.find_spec("xgboost") is not None: from openfl.federated.task.runner_xgb import XGBoostTaskRunner # NOQA +if util.find_spec("flwr") is not None: + from openfl.federated.task.runner_flower import FlowerTaskRunner # NOQA diff --git a/openfl/federated/task/runner.py b/openfl/federated/task/runner.py index 6689041a3a..70fe8bd51b 100644 --- a/openfl/federated/task/runner.py +++ b/openfl/federated/task/runner.py @@ -39,7 +39,10 @@ def __init__(self, data_loader, tensor_dict_split_fn_kwargs: dict = None, **kwar **kwargs: Additional parameters to pass to the function. """ self.data_loader = data_loader - self.feature_shape = self.data_loader.get_feature_shape() + if self.data_loader: + self.feature_shape = self.data_loader.get_feature_shape() + else: + self.feature_shape = None # TODO: Should this comment a path of the doc string? # key word arguments for determining which parameters to hold out from # aggregation. diff --git a/openfl/federated/task/runner_flower.py b/openfl/federated/task/runner_flower.py new file mode 100644 index 0000000000..ae58e05630 --- /dev/null +++ b/openfl/federated/task/runner_flower.py @@ -0,0 +1,187 @@ +from openfl.federated.task.runner import TaskRunner +import subprocess +from logging import getLogger +import time +import os +import numpy as np +from pathlib import Path +import sys +import socket + +os.environ["FLWR_HOME"] = os.path.join(os.getcwd(), "src/.flwr") +os.makedirs(os.environ["FLWR_HOME"], exist_ok=True) + +class FlowerTaskRunner(TaskRunner): + """ + FlowerTaskRunner is a task runner that executes Flower SuperNode + to initialize the experiment from the client side. 
+
+    This class is responsible for starting a local gRPC server and a Flower SuperNode
+    in a subprocess. It also provides an option for automatic shutdown at the end of
+    the experiment.
+
+    Shutdown Options:
+        - Manual Shutdown: The server and SuperNode process can be manually stopped by pressing CTRL+C.
+        - Automatic Shutdown: If enabled, the runner shuts the SuperNode down once the
+          aggregator signals that the experiment has ended.
+    """
+    def __init__(self, **kwargs):
+        """
+        Initializes the FlowerTaskRunner.
+
+        Args:
+            auto_shutdown (bool): Whether to enable automatic shutdown. Default is True.
+                Set to False for long-lived components.
+            **kwargs: Additional parameters, including `patch`, `flwr_app_name`,
+                and `client_port`.
+        """
+        super().__init__(**kwargs)
+
+        self.patch = kwargs.get('patch')
+        if self.data_loader is None:
+            flwr_app_name = kwargs.get('flwr_app_name')
+            if self.patch:
+                install_flower_FAB(flwr_app_name)
+            return
+
+        self.model = None
+        self.logger = getLogger(__name__)
+
+        self.data_path = self.data_loader.get_node_configs()
+
+        self.client_port = kwargs.get('client_port')
+        if self.client_port is None:
+            self.client_port = get_dynamic_port()
+
+        self.shutdown_requested = False  # Flag to signal shutdown
+
+    def start_client_adapter(self, local_grpc_server, **kwargs):
+        """
+        Starts the local gRPC server and the Flower SuperNode.
+        """
+        local_server_port = kwargs.get('local_server_port')
+
+        def message_callback():
+            self.shutdown_requested = True
+
+        # TODO: Can we isolate the local_grpc_server from the task runner?
+        local_grpc_server.set_end_experiment_callback(message_callback)
+        local_grpc_server.start_server(local_server_port)
+
+        local_server_port = local_grpc_server.get_port()
+
+        if self.patch:
+            command = [
+                "python",
+                "src/patch/flower_supernode_patch.py",
+                "--insecure",
+                "--grpc-adapter",
+                "--superlink", f"127.0.0.1:{local_server_port}",
+                "--clientappio-api-address", f"127.0.0.1:{self.client_port}",
+                "--node-config", f"data-path='{self.data_path}'"
+            ]
+        else:
+            command = [
+                "flower-supernode",
+                "--insecure",
+                "--grpc-adapter",
+                "--superlink", f"127.0.0.1:{local_server_port}",
+                "--clientappio-api-address", f"127.0.0.1:{self.client_port}",
+                "--node-config", f"data-path='{self.data_path}'"
+            ]
+
+        supernode_process = subprocess.Popen(command, shell=False)
+        local_grpc_server.handle_signals(supernode_process)
+
+        self.logger.info("Press CTRL+C to stop the server and SuperNode process.")
+
+        try:
+            while not local_grpc_server.termination_event.is_set():
+                if self.shutdown_requested:
+                    local_grpc_server.terminate_supernode_process(supernode_process)
+                    local_grpc_server.stop_server()
+                time.sleep(0.1)
+        except KeyboardInterrupt:
+            local_grpc_server.terminate_supernode_process(supernode_process)
+            local_grpc_server.stop_server()
+
+    def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
+        """Set the tensor dictionary.
+
+        To be framework agnostic, this method will not attempt to load the weights into
+        the model and save out the native format. Instead, it loads and saves the
+        dictionary directly.
+
+        Args:
+            tensor_dict (dict): The tensor dictionary.
+            with_opt_vars (bool): This argument is inherited from the parent class
+                but is not used in the FlowerTaskRunner.
+        """
+        self.tensor_dict = tensor_dict
+
+    def save_native(
+        self,
+        filepath,
+        **kwargs,
+    ):
+        """
+        Save model weights in a .npz file specified by the filepath.
+        The model weights are stored as a dictionary of np.ndarray.
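+        The archive can later be reloaded with numpy.load() to recover each weight by name.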
+
+        Args:
+            filepath (str): Path to the .npz file to be created by np.savez().
+            **kwargs: Additional parameters (currently not used).
+
+        Returns:
+            None
+
+        Raises:
+            AssertionError: If the file extension is not '.npz'.
+        """
+        if isinstance(filepath, Path):
+            filepath = str(filepath)
+
+        # Ensure the file extension is .npz
+        assert filepath.endswith('.npz'), "Currently, only '.npz' file type is supported."
+
+        # Save the tensor dictionary to a .npz file
+        np.savez(filepath, **self.tensor_dict)
+
+    def initialize_tensorkeys_for_functions(self, with_opt_vars=False):
+        pass
+
+
+def install_flower_FAB(flwr_app_name):
+    """Build and install the Flower application bundle (FAB)."""
+    flwr_dir = os.environ["FLWR_HOME"]
+
+    # Run the build command
+    subprocess.check_call([
+        sys.executable,
+        "src/patch/flwr_run_patch.py",
+        "build",
+        "--app",
+        f"./src/{flwr_app_name}"
+    ])
+
+    # List .fab files after running the build command
+    fab_files = list(Path(flwr_dir).glob("*.fab"))
+
+    # Determine the newest .fab file
+    newest_fab_file = max(fab_files, key=os.path.getmtime)
+
+    # Run the install command using the newest .fab file
+    subprocess.check_call([
+        sys.executable,
+        "src/patch/flwr_run_patch.py",
+        "install",
+        str(newest_fab_file)
+    ])
+
+def get_dynamic_port():
+    # Create a socket
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        # Bind to port 0 to let the OS assign an available port
+        s.bind(('', 0))
+        # Get the assigned port number
+        port = s.getsockname()[1]
+    return port
\ No newline at end of file
diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py
index e39abdbb07..4e7729ce86 100644
--- a/openfl/interface/plan.py
+++ b/openfl/interface/plan.py
@@ -111,7 +111,7 @@ def initialize(
     Initializes a Data Science plan and generates a protobuf file of the
     initial model weights for the federation.
""" - + for p in [plan_config, cols_config, data_config]: if is_directory_traversal(p): echo(f"{p} is out of the openfl workspace scope.") @@ -130,38 +130,44 @@ def initialize( gandlf_config_path=gandlf_config, ) - init_state_path = plan.config["aggregator"]["settings"]["init_state_path"] - # This is needed to bypass data being locally available - if input_shape is not None: - logger.info( - f"Attempting to generate initial model weights with custom input shape {input_shape}" + if 'connector' in plan.config: + logger.info("OpenFL Connector enabled: %s", plan.config['connector']) + # Only need to initialize task runner to install apps/packages + # that were not installable via requirements.txt + plan.get_task_runner(data_loader=None) + else: + init_state_path = plan.config["aggregator"]["settings"]["init_state_path"] + # This is needed to bypass data being locally available + if input_shape is not None: + logger.info( + f"Attempting to generate initial model weights with custom input shape {input_shape}" + ) + + # Initialize tensor dictionary + init_tensor_dict, task_runner, round_number = _initialize_tensor_dict( + plan, input_shape, init_model_path ) - # Initialize tensor dictionary - init_tensor_dict, task_runner, round_number = _initialize_tensor_dict( - plan, input_shape, init_model_path - ) - - tensor_dict, holdout_params = split_tensor_dict_for_holdouts( - init_tensor_dict, - **task_runner.tensor_dict_split_fn_kwargs, - ) - - logger.warning( - f"Following parameters omitted from global initial model, " - f"local initialization will determine" - f" values: {list(holdout_params.keys())}" - ) + tensor_dict, holdout_params = split_tensor_dict_for_holdouts( + init_tensor_dict, + **task_runner.tensor_dict_split_fn_kwargs, + ) - # Save the model state - try: - logger.info(f"Saving model state to {init_state_path}") - plan.save_model_to_state_file( - tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path + logger.warning( + f"Following parameters omitted from global initial model, " + f"local initialization will determine" + f" values: {list(holdout_params.keys())}" ) - except Exception as e: - logger.error(f"Failed to save model state: {e}") - raise + + # Save the model state + try: + logger.info(f"Saving model state to {init_state_path}") + plan.save_model_to_state_file( + tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path + ) + except Exception as e: + logger.error(f"Failed to save model state: {e}") + raise plan_origin = Plan.parse( plan_config_path=plan_config, diff --git a/openfl/protocols/aggregator.proto b/openfl/protocols/aggregator.proto index 7a048eaf76..71e216dd94 100644 --- a/openfl/protocols/aggregator.proto +++ b/openfl/protocols/aggregator.proto @@ -12,6 +12,7 @@ service Aggregator { rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {} rpc GetAggregatedTensor(GetAggregatedTensorRequest) returns (GetAggregatedTensorResponse) {} rpc SendLocalTaskResults(stream DataStream) returns (SendLocalTaskResultsResponse) {} + rpc PelicanDrop(DropPod) returns (DropPod) {} } message MessageHeader { @@ -69,3 +70,9 @@ message TaskResults { message SendLocalTaskResultsResponse { MessageHeader header = 1; } + +message DropPod { + MessageHeader header = 1; + DataStream message = 2; + map metadata = 3; +} diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index ffd95f40ce..c4f3fdc925 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py 
@@ -16,6 +16,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 class ConstantBackoff:
     """Constant Backoff policy.
@@ -371,8 +372,8 @@ def send_local_task_results(
         collaborator_name,
         round_number,
         task_name,
-        data_size,
-        named_tensors,
+        data_size=None,
+        named_tensors=None,
     ):
         """
         Send task results to the aggregator.
@@ -402,3 +403,27 @@ def send_local_task_results(
         # convert (potentially) long list of tensors into stream
         response = self.stub.SendLocalTaskResults(utils.proto_to_datastream(request))
         self.validate_response(response, collaborator_name)
+
+    @_atomic_connection
+    @_resend_data_on_reconnection
+    def send_message_to_server(self, openfl_message, collaborator_name):
+        """
+        Forward a converted message from the local gRPC server (LGS) to the OpenFL
+        server and return the response.
+
+        Args:
+            openfl_message: The converted message from the LGS to be sent to the OpenFL server.
+            collaborator_name: The name of the collaborator.
+
+        Returns:
+            The response from the OpenFL server.
+        """
+        header = create_header(
+            sender=collaborator_name,
+            receiver=self.aggregator_uuid,
+            federation_uuid=self.federation_uuid,
+            single_col_cert_common_name=self.single_col_cert_common_name,
+        )
+        openfl_message.header.CopyFrom(header)
+        openfl_response = self.stub.PelicanDrop(openfl_message)
+        self.validate_response(openfl_response, collaborator_name)
+        return openfl_response
diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py
index 46b8a840a0..f71af9322c 100644
--- a/openfl/transport/grpc/aggregator_server.py
+++ b/openfl/transport/grpc/aggregator_server.py
@@ -54,6 +54,13 @@ def __init__(
         self.root_certificate = root_certificate
         self.certificate = certificate
         self.private_key = private_key
+        self.use_connector = self.aggregator.is_connector_available()
+
+        if self.use_connector:
+            self.local_grpc_client = self.aggregator.get_local_grpc_client()  # Initialize the local gRPC client
+        else:
+            self.local_grpc_client = None
+
         self.root_certificate_refresher_cb = root_certificate_refresher_cb
 
     def validate_collaborator(self, request, context):
@@ -196,6 +203,9 @@ def GetAggregatedTensor(self, request, context):  # NOQA:N802
             aggregator_pb2.GetAggregatedTensorResponse: The response to
                 the request.
         """
+        if self.use_connector:
+            context.abort(StatusCode.UNIMPLEMENTED,
+                          "This method is not available in framework interoperability mode.")
+
         self.validate_collaborator(request, context)
         self.check_request(request)
         collaborator_name = request.header.sender
@@ -242,6 +252,9 @@ def SendLocalTaskResults(self, request, context):  # NOQA:N802
             aggregator_pb2.SendLocalTaskResultsResponse: The response to
                 the request.
         """
+        # if self.use_connector:
+        #     context.abort(StatusCode.UNIMPLEMENTED,
+        #                   "This method is not available in framework interoperability mode.")
+
         try:
             proto = aggregator_pb2.TaskResults()
             proto = utils.datastream_to_proto(proto, request)
@@ -269,10 +282,43 @@ def SendLocalTaskResults(self, request, context):  # NOQA:N802
             federation_uuid=self.aggregator.federation_uuid,
             single_col_cert_common_name=self.aggregator.single_col_cert_common_name,
         )
+
         return aggregator_pb2.SendLocalTaskResultsResponse(header=header)
 
+    def PelicanDrop(self, request, context):
+        """
+        Relay a DropPod message from a collaborator to the Connector's local gRPC client.
+
+        Args:
+            request (aggregator_pb2.DropPod): The request from the collaborator.
+            context (grpc.ServicerContext): The context of the request.
+
+        Returns:
+            aggregator_pb2.DropPod: The response to the request.
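+
+        Note:
+            Aborts the RPC with StatusCode.UNIMPLEMENTED when no Connector is
+            configured on the aggregator.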
+ """ + if not self.use_connector: + context.abort(StatusCode.UNIMPLEMENTED, "PelicanDrop is only available in federated interopability mode.") + + self.validate_collaborator(request, context) + self.check_request(request) + collaborator_name = request.header.sender + + header = create_header( + sender=self.aggregator.uuid, + receiver=collaborator_name, + federation_uuid=self.aggregator.federation_uuid, + single_col_cert_common_name=self.aggregator.single_col_cert_common_name, + ) + + # Forward the incoming OpenFL message to the local gRPC client + return self.local_grpc_client.send_receive(request, header=header) + def serve(self): """Starts the aggregator gRPC server.""" + + if self.use_connector: + self.aggregator.start_connector() + server = create_grpc_server( self.uri, self.use_tls, @@ -290,4 +336,7 @@ def serve(self): while not self.aggregator.all_quit_jobs_sent(): sleep(5) + if self.use_connector: + self.aggregator.stop_connector() + server.stop(0) diff --git a/openfl/transport/grpc/connector/__init__.py b/openfl/transport/grpc/connector/__init__.py new file mode 100644 index 0000000000..8bc7eca6d9 --- /dev/null +++ b/openfl/transport/grpc/connector/__init__.py @@ -0,0 +1,3 @@ +from openfl.transport.grpc.connector.utils import get_local_grpc_server + +__all__ = ['get_local_grpc_server'] \ No newline at end of file diff --git a/openfl/transport/grpc/connector/flower/__init__.py b/openfl/transport/grpc/connector/flower/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfl/transport/grpc/connector/flower/deserialize_message.py b/openfl/transport/grpc/connector/flower/deserialize_message.py new file mode 100644 index 0000000000..46f37e8acc --- /dev/null +++ b/openfl/transport/grpc/connector/flower/deserialize_message.py @@ -0,0 +1,41 @@ +import importlib +from google.protobuf.message import DecodeError + +def deserialize_flower_message(flower_message): + """ + Deserialize the grpc_message_content of a Flower message using the module and class name + specified in the metadata. + + Args: + flower_message: The Flower message containing the metadata and binary content. + + Returns: + The deserialized message object, or None if deserialization fails. + """ + # Access metadata directly + metadata = flower_message.metadata + module_name = metadata.get('grpc-message-module') + qualname = metadata.get('grpc-message-qualname') + + # Import the module + try: + module = importlib.import_module(module_name) + except ImportError as e: + print(f"Failed to import module: {module_name}. Error: {e}") + return None + + # Get the message class + try: + message_class = getattr(module, qualname) + except AttributeError as e: + print(f"Failed to get message class '{qualname}' from module '{module_name}'. Error: {e}") + return None + + # Deserialize the content + try: + message = message_class.FromString(flower_message.grpc_message_content) + except DecodeError as e: + print(f"Failed to deserialize message content. 
Error: {e}") + return None + + return message \ No newline at end of file diff --git a/openfl/transport/grpc/connector/flower/local_grpc_client.py b/openfl/transport/grpc/connector/flower/local_grpc_client.py new file mode 100644 index 0000000000..90d7618be0 --- /dev/null +++ b/openfl/transport/grpc/connector/flower/local_grpc_client.py @@ -0,0 +1,54 @@ +import grpc +from flwr.proto import grpcadapter_pb2_grpc +from openfl.transport.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message +from logging import getLogger + +class LocalGRPCClient: + """ + LocalGRPCClient facilitates communication between the Flower SuperLink + and the OpenFL Server. It converts messages between OpenFL and Flower formats + and handles the send-receive communication with the Flower SuperNode using gRPC. + """ + def __init__(self, superlink_address, automatic_shutdown=False): + """ + Initialize. + + Args: + superlink_address: The address the Flower SuperLink will listen on + """ + self.superlink_channel = grpc.insecure_channel(superlink_address) + self.superlink_stub = grpcadapter_pb2_grpc.GrpcAdapterStub(self.superlink_channel) + + self.automatic_shutdown = automatic_shutdown + self.end_experiment = False + self.is_flwr_serverapp_running_callback = None + + self.logger = getLogger(__name__) + + def set_is_flwr_serverapp_running_callback(self, is_flwr_serverapp_running_callback): + self.is_flwr_serverapp_running_callback = is_flwr_serverapp_running_callback + + def send_receive(self, openfl_message, header): + """ + Sends a message to the Flower SuperLink and receives the response. + + Args: + openfl_message: converted Flower SuperNode request sent by OpenFL server + header: OpenFL header information to be included in the message. + + Returns: + The response from the Flower SuperLink, converted back to OpenFL format. + """ + flower_message = openfl_to_flower_message(openfl_message) + flower_response = self.superlink_stub.SendReceive(flower_message) + + if self.automatic_shutdown and self.is_flwr_serverapp_running_callback: + # Check if the flwr_serverapp subprocess is still running, if it isn't + # then the experiment has completed + self.end_experiment = not self.is_flwr_serverapp_running_callback() + + openfl_response = flower_to_openfl_message(flower_response, + header=header, + end_experiment=self.end_experiment) + + return openfl_response diff --git a/openfl/transport/grpc/connector/flower/local_grpc_server.py b/openfl/transport/grpc/connector/flower/local_grpc_server.py new file mode 100644 index 0000000000..85314948c1 --- /dev/null +++ b/openfl/transport/grpc/connector/flower/local_grpc_server.py @@ -0,0 +1,141 @@ +import logging +import threading +import queue +import grpc +from concurrent.futures import ThreadPoolExecutor +from flwr.proto import grpcadapter_pb2_grpc +from openfl.transport.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message +from multiprocessing import cpu_count +import signal +import psutil +import time + +logger = logging.getLogger(__name__) + + +class LocalGRPCServer(grpcadapter_pb2_grpc.GrpcAdapterServicer): + """ + LocalGRPCServer is a gRPC server that handles requests from the Flower SuperNode + and forwards them to the OpenFL Client. It uses a queue-based system to + ensure that requests are processed sequentially, preventing concurrent + request handling issues. + """ + + def __init__(self, openfl_client, collaborator_name): + """ + Initialize. 
+
+        Args:
+            openfl_client: An instance of the OpenFL client.
+            collaborator_name: The name of the collaborator.
+        """
+        self.openfl_client = openfl_client
+        self.collaborator_name = collaborator_name
+        self.end_experiment_callback = None
+        self.request_queue = queue.Queue()
+        self.processing_thread = threading.Thread(target=self.process_queue)
+        self.processing_thread.daemon = True
+        self.processing_thread.start()
+        self.server = None
+        self.termination_event = threading.Event()
+        self.port = None
+
+    def set_end_experiment_callback(self, callback):
+        self.end_experiment_callback = callback
+
+    def start_server(self, local_server_port):
+        """Starts the gRPC server."""
+        self.server = grpc.server(ThreadPoolExecutor(max_workers=cpu_count()))
+        grpcadapter_pb2_grpc.add_GrpcAdapterServicer_to_server(self, self.server)
+        self.port = self.server.add_insecure_port(f'[::]:{local_server_port}')
+        self.server.start()
+        logger.info(f"OpenFL local gRPC server started, listening on port {self.port}.")
+
+    def get_port(self):
+        # Return the port that was assigned
+        return self.port
+
+    def stop_server(self):
+        """Stops the gRPC server."""
+        if self.server:
+            logger.info("Shutting down local gRPC server...")
+            self.server.stop(0)
+            logger.info("Local gRPC server stopped.")
+            self.termination_event.set()
+
+    def SendReceive(self, request, context):
+        """
+        Handles incoming gRPC requests by putting them into the request queue
+        and waiting for the response.
+
+        Args:
+            request: The incoming gRPC request.
+            context: The gRPC context.
+
+        Returns:
+            The response from the OpenFL server.
+        """
+        response_queue = queue.Queue()
+        self.request_queue.put((request, response_queue))
+        return response_queue.get()
+
+    def process_queue(self):
+        """
+        Continuously processes requests from the request queue. Each request is
+        sent to the OpenFL server, and the response is put into the corresponding
+        response queue.
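+
+        Runs forever on a daemon thread; each iteration blocks until the next
+        request arrives on the queue.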
+ """ + while True: + request, response_queue = self.request_queue.get() + openfl_request = flower_to_openfl_message(request) + + # Send request to the OpenFL server + openfl_response = self.openfl_client.send_message_to_server(openfl_request, self.collaborator_name) + + # Check to end experiment + if hasattr(openfl_response, 'metadata'): + if openfl_response.metadata['end_experiment'] == 'True': + self.end_experiment_callback() + + # Send response to Flower client + flower_response = openfl_to_flower_message(openfl_response) + response_queue.put(flower_response) + self.request_queue.task_done() + + def handle_signals(self, supernode_process): + """Sets up signal handlers for graceful shutdown.""" + def signal_handler(_sig, _frame): + self.terminate_supernode_process(supernode_process) + self.stop_server() + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + def terminate_supernode_process(self, supernode_process): + """Terminates the SuperNode process.""" + if supernode_process.poll() is None: + try: + main_subprocess = psutil.Process(supernode_process.pid) + client_app_processes = main_subprocess.children(recursive=True) + + for client_app_process in client_app_processes: + self.terminate_process(client_app_process) + + self.terminate_process(main_subprocess) + logger.info("SuperNode process terminated.") + + except Exception as e: + logger.debug(f"Error during graceful shutdown: {e}") + time.sleep(10) + supernode_process.kill() + logger.info("SuperNode process terminated.") + else: + logger.info("SuperNode process already terminated.") + + def terminate_process(self, process, timeout=5): + """Helper function to terminate a process gracefully.""" + try: + process.terminate() + process.wait(timeout=timeout) + except psutil.TimeoutExpired: + logger.debug(f"Timeout expired while waiting for process {process.pid} to terminate. Killing the process.") + process.kill() + except psutil.NoSuchProcess: + logger.debug(f"Process {process.pid} does not exist. Skipping.") + pass \ No newline at end of file diff --git a/openfl/transport/grpc/connector/flower/message_conversion.py b/openfl/transport/grpc/connector/flower/message_conversion.py new file mode 100644 index 0000000000..ded2235659 --- /dev/null +++ b/openfl/transport/grpc/connector/flower/message_conversion.py @@ -0,0 +1,65 @@ +from flwr.proto import grpcadapter_pb2 +from openfl.protocols import aggregator_pb2 + +def flower_to_openfl_message(flower_message, + header=None, + end_experiment=False): + """ + Convert a Flower MessageContainer to an OpenFL DropPod. + + This function takes a Flower MessageContainer and converts it into an OpenFL DropPod. + If the input is already an OpenFL DropPod, it returns the input as-is. + + Args: + flower_message (grpcadapter_pb2.MessageContainer or aggregator_pb2.DropPod): + The Flower message to be converted. It can either be a Flower MessageContainer + or an OpenFL DropPod. + header (aggregator_pb2.MessageHeader, optional): + An optional header to be included in the OpenFL DropPod. If provided, + it will be copied to the DropPod's header field. + + Returns: + aggregator_pb2.DropPod: The converted OpenFL DropPod message. 
+ """ + if isinstance(flower_message, aggregator_pb2.DropPod): + # If the input is already an OpenFL message, return it as-is + return flower_message + else: + # Create the OpenFL message + openfl_message = aggregator_pb2.DropPod() + # Set the MessageHeader fields based on the provided sender and receiver + if header: + openfl_message.header.CopyFrom(header) + + # Serialize the Flower message and set it in the OpenFL message + serialized_flower_message = flower_message.SerializeToString() + openfl_message.message.npbytes = serialized_flower_message + openfl_message.message.size = len(serialized_flower_message) + + # Add flag to check if experiment has ended + openfl_message.metadata.update({"end_experiment": str(end_experiment)}) + return openfl_message + +def openfl_to_flower_message(openfl_message): + """ + Convert an OpenFL DropPod to a Flower MessageContainer. + + This function takes an OpenFL DropPod and converts it into a Flower MessageContainer. + If the input is already a Flower MessageContainer, it returns the input as-is. + + Args: + openfl_message (aggregator_pb2.DropPod or grpcadapter_pb2.MessageContainer): + The OpenFL message to be converted. It can either be an OpenFL DropPod + or a Flower MessageContainer. + + Returns: + grpcadapter_pb2.MessageContainer: The converted Flower MessageContainer. + """ + if isinstance(openfl_message, grpcadapter_pb2.MessageContainer): + # If the input is already a Flower message, return it as-is + return openfl_message + else: + # Deserialize the Flower message from the DataStream npbytes field + flower_message = grpcadapter_pb2.MessageContainer() + flower_message.ParseFromString(openfl_message.message.npbytes) + return flower_message \ No newline at end of file diff --git a/openfl/transport/grpc/connector/utils.py b/openfl/transport/grpc/connector/utils.py new file mode 100644 index 0000000000..36d8b75c6b --- /dev/null +++ b/openfl/transport/grpc/connector/utils.py @@ -0,0 +1,10 @@ +import importlib + +def get_local_grpc_server(framework: str = 'Flower') -> object: + if framework == 'Flower': + try: + module = importlib.import_module('openfl.transport.grpc.connector.flower.local_grpc_server') + return module.LocalGRPCServer + except ImportError: + print("Flower is not installed.") + return None \ No newline at end of file