diff --git a/.coverage b/.coverage index db6cd5276e..d4dfa752f9 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.github/actions/tr_post_test_run/action.yml b/.github/actions/tr_post_test_run/action.yml index d227ca45de..3a0ca6c42a 100644 --- a/.github/actions/tr_post_test_run/action.yml +++ b/.github/actions/tr_post_test_run/action.yml @@ -24,7 +24,7 @@ runs: id: tar_files if: ${{ always() }} run: | - tar -cvf result.tar --exclude="cert" --exclude="data" --exclude="__pycache__" --exclude="tensor.db" --exclude="workspace.tar" $HOME/results + tar -cvf result.tar --exclude="cert" --exclude="data" --exclude="__pycache__" --exclude="tensor.db" --exclude="workspace.tar" --exclude="minio_data" $HOME/results # Model name might contain forward slashes, convert them to underscore. tmp=${{ env.MODEL_NAME }} echo "MODEL_NAME_MODIFIED=${tmp//\//_}" >> $GITHUB_ENV diff --git a/.github/scripts/.coverage b/.github/scripts/.coverage new file mode 100644 index 0000000000..0e23bd0216 Binary files /dev/null and b/.github/scripts/.coverage differ diff --git a/.github/scripts/coverage-report.sh b/.github/scripts/coverage-report.sh new file mode 100755 index 0000000000..885d5a3426 --- /dev/null +++ b/.github/scripts/coverage-report.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +set -Eeuo pipefail + +rm -rf build + +pip uninstall openfl -y + +pip install -e . + +pip install -r test-requirements.txt + +pip install -r openfl-tutorials/experimental/workflow/workflow_interface_requirements.txt + +pip install coverage + +pip install pytest-cov + +fx experimental deactivate + +rm -rf .coverage + +python -m pytest -rA --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py -k test_federation_via_native --model_name torch/mnist --num_rounds 2 --disable_client_auth --secure_agg --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py -k test_federation_via_native --model_name keras/jax/mnist --num_rounds 2 --disable_tls --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/memory_logs_tests.py -k test_log_memory_usage_basic --model_name torch/histology --num_rounds 2 --log_memory_usage --secure_agg --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/tr_resiliency_tests.py --model_name keras/torch/mnist --num_rounds 25 --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/tr_flower_tests.py -k test_flower_app_pytorch_native --model_name flower-app-pytorch --num_rounds 1 --cov-report=term-missing --cov-append --cov=openfl + + +python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py -m task_runner_dockerized_ws --num_rounds 2 --model_name keras/torch/mnist --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/tr_with_fedeval_tests.py -m task_runner_basic --model_name xgb_higgs --num_rounds 2 --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/wf_local_func_tests.py --num_rounds 2 --cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/wf_local_func_tests.py --workflow_backend ray --num_rounds 2 --cov-report=term-missing --cov-append --cov=openfl + +fx experimental activate + +python -m pytest -s tests/end_to_end/test_suites/wf_federated_runtime_tests.py -k test_federated_runtime_301_watermarking 
--cov-report=term-missing --cov-append --cov=openfl + +python -m pytest -s tests/end_to_end/test_suites/wf_federated_runtime_tests.py -k test_federated_runtime_secure_aggregation --cov-report=term-missing --cov-append --cov=openfl + +# python -m pytest -s tests/end_to_end/test_suites/wf_federated_runtime_tests.py -k test_federated_evaluation --cov-report=term-missing --cov-append --cov=openfl + +fx experimental deactivate +# Combine and generate the final coverage report +coverage report + + + + + + diff --git a/.github/workflows/pq_pipeline.yml b/.github/workflows/pq_pipeline.yml index 41d5ac99c9..4a5107a485 100644 --- a/.github/workflows/pq_pipeline.yml +++ b/.github/workflows/pq_pipeline.yml @@ -80,7 +80,7 @@ jobs: (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') name: TaskRunner E2E - needs: set_commit_id_for_all_jobs + needs: task_runner_connectivity_e2e uses: ./.github/workflows/task_runner_basic_e2e.yml with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} @@ -90,7 +90,7 @@ jobs: (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') name: TaskRunner Resiliency E2E - needs: task_runner_e2e + needs: task_runner_connectivity_e2e uses: ./.github/workflows/task_runner_resiliency_e2e.yml with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} @@ -158,6 +158,26 @@ jobs: with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + task_runner_fed_analytics_e2e: + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') + name: TaskRunner Federated Analytics E2E + needs: task_runner_connectivity_e2e + uses: ./.github/workflows/task_runner_fed_analytics_e2e.yml + with: + commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + + tr_verifiable_dataset_e2e: + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') + name: TaskRunner Verifiable Dataset E2E + needs: task_runner_e2e + uses: ./.github/workflows/tr_verifiable_dataset_e2e.yml + with: + commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + run_trivy: if: | (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || @@ -198,7 +218,9 @@ jobs: wf_mnist_local_runtime, wf_watermark_e2e, wf_secagg_e2e, + task_runner_connectivity_e2e, task_runner_e2e, + task_runner_fed_analytics_e2e, task_runner_resiliency_e2e, task_runner_fedeval_e2e, task_runner_secure_agg_e2e, @@ -206,6 +228,7 @@ jobs: task_runner_dockerized_e2e, task_runner_secret_ssl_e2e, task_runner_flower_app_pytorch, + tr_verifiable_dataset_e2e, run_trivy, run_bandit ] diff --git a/.github/workflows/task_runner_fed_analytics_e2e.yml b/.github/workflows/task_runner_fed_analytics_e2e.yml new file mode 100644 index 0000000000..0fc5b4caf0 --- /dev/null +++ b/.github/workflows/task_runner_fed_analytics_e2e.yml @@ -0,0 +1,108 @@ +--- +# Task Runner Federated Analytics E2E tests for bare metal approach + +name: Task_Runner_Fed_Analytics_E2E # Please do not modify the name as it is used in the composite action + +on: + workflow_call: + inputs: + commit_id: + required: false + type: string + workflow_dispatch: + inputs: + num_collaborators: + description: "Number of collaborators" + required: false + default: "2" + type: string + python_version: + description: 
"Python version" + required: false + default: "3.10" + type: choice + options: + - "3.10" + - "3.11" + - "3.12" + +permissions: + contents: read + +# Environment variables common for all the jobs +# DO NOT use double quotes for the values of the environment variables +env: + NUM_COLLABORATORS: ${{ inputs.num_collaborators || 2 }} + COMMIT_ID: ${{ inputs.commit_id || github.sha }} # use commit_id from the calling workflow + +jobs: + test_fed_analytics_histogram: + name: With TLS (federated_analytics/histogram, 3.11) # DO NOT change this name. + runs-on: ubuntu-22.04 + timeout-minutes: 30 + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') || + (github.event.pull_request.draft == false) + env: + MODEL_NAME: 'federated_analytics/histogram' + PYTHON_VERSION: ${{ inputs.python_version || '3.11' }} + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Run Federated Analytics Histogram + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_fed_analytics_tests.py \ + -m task_runner_fed_analytics --model_name ${{ env.MODEL_NAME }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Federated analytics histogram test run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "Sepal_Histogram_Analytics" + + test_fed_analytics_smokers_health: + name: With TLS (federated_analytics/smokers_health, 3.12) # DO NOT change this name. + runs-on: ubuntu-22.04 + timeout-minutes: 30 + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') || + (github.event.pull_request.draft == false) + env: + MODEL_NAME: 'federated_analytics/smokers_health' + PYTHON_VERSION: ${{ inputs.python_version || '3.12' }} + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Run Federated Analytics Smokers Health + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_fed_analytics_tests.py \ + -m task_runner_fed_analytics --model_name ${{ env.MODEL_NAME }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Federated analytics smokers health test run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "Smokers_Health_Analytics" \ No newline at end of file diff --git a/.github/workflows/tr_verifiable_dataset_e2e.yml b/.github/workflows/tr_verifiable_dataset_e2e.yml new file mode 100644 index 0000000000..e2f77095a2 --- /dev/null +++ b/.github/workflows/tr_verifiable_dataset_e2e.yml @@ -0,0 +1,81 @@ +--- +# Task Runner Verifiable Dataset E2E + +name: TR_Verifiable_Dataset_E2E # Please do not modify the name as it is used in the composite action + +on: + workflow_call: + inputs: + commit_id: + required: false + type: string + workflow_dispatch: + inputs: + num_rounds: + description: "Number of rounds to train" + required: false + default: "2" + type: string + num_collaborators: + description: "Number of collaborators" + required: false + default: "2" + type: string + +permissions: + contents: read + +# Environment 
variables common for all the jobs +# DO NOT use double quotes for the values of the environment variables +env: + NUM_ROUNDS: ${{ inputs.num_rounds || 2 }} + NUM_COLLABORATORS: ${{ inputs.num_collaborators || 2 }} + COMMIT_ID: ${{ inputs.commit_id || github.sha }} # use commit_id from the calling workflow + +jobs: + test_with_s3: # Run it only if the runner machine has enough memory and CPU + name: With S3 (torch/histology_s3, 3.11) + runs-on: ubuntu-22.04 + timeout-minutes: 120 + env: + MODEL_NAME: "torch/histology_s3" + PYTHON_VERSION: "3.11" + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Install MinIO + id: install_minio + run: | + wget https://dl.min.io/server/minio/release/linux-amd64/minio + chmod +x minio + sudo mv minio /usr/local/bin/ + + - name: Install MinIO Client + id: install_minio_client + run: | + wget https://dl.min.io/client/mc/release/linux-amd64/mc + chmod +x mc + sudo mv mc /usr/local/bin/ + + - name: Run Task Runner E2E tests with S3 + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py \ + -m task_runner_with_s3 --model_name ${{ env.MODEL_NAME }} \ + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Task Runner E2E tests with S3 run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "With_S3" diff --git a/openfl-workspace/flower-app-pytorch/README.md b/openfl-workspace/flower-app-pytorch/README.md index c42ee65a61..a18d6059c8 100644 --- a/openfl-workspace/flower-app-pytorch/README.md +++ b/openfl-workspace/flower-app-pytorch/README.md @@ -28,7 +28,7 @@ Then create a certificate authority (CA) fx workspace certify ``` -This will create a workspace in your current working directory called `./my_workspace` as well as install the Flower app defined in `./app-pytorch.` This will be where the experiment takes place. The CA will be used to sign the certificates of the collaborators. +This will create a workspace in your current working directory called `./my_workspace` as well as install the Flower app defined in `./src/app-pytorch`. This will be where the experiment takes place. The CA will be used to sign the certificates of the collaborators. ### Setup Data We will be using CIFAR10 dataset. You can install an automatically partition it into 2 using the `./src/setup_data.py` script provided. @@ -63,44 +63,52 @@ data/ Notice under `./plan`, you will find the familiar OpenFL YAML files to configure the experiment. `cols.yaml` and `data.yaml` will be populated by the collaborators that will run the Flower client app and the respective data shard or directory they will perform their training and testing on. `plan.yaml` configures the experiment itself. The Open-Flower integration makes a few key changes to the `plan.yaml`: -1. Introduction of a new top-level key (`connector`) to configure a newly introduced component called `ConnectorFlower`. This component is run by the aggregator and is responsible for initializing the Flower `SuperLink` and connecting to the OpenFL server. The `SuperLink` parameters can be configured using `connector.settings.superlink_params`. If nothing is supplied, it will simply run `flower-superlink --insecure` with the command's default settings as dictated by Flower.
It also includes the option to run the flwr run command via `connector.settings.flwr_run_params`. If `flwr_run_params` are not provided, the user will be expected to run `flwr run ` from the aggregator machine to initiate the experiment. Additionally, the `ConnectorFlower` has an additional setting `connector.settings.automatic_shutdown` which is default set to `True`. When set to `True`, the task runner will shut the SuperNode at the completion of an experiment, otherwise, it will run continuously. +1. Introduction of a new top-level key (`connector`) to configure a newly introduced component called `ConnectorFlower`. This component is run by the aggregator and is responsible for initializing the Flower `SuperLink` and connecting to the OpenFL server. Under `settings`, you will find the parameters for both the `flower-superlink` and `flwr run` commands. All parameters are configurable by the user. By default, the `flower-superlink` will be run in `insecure` mode. The `fleet_api_port` and `serverappio_api_port` will be automatically assigned by default, while the `exec_api_port` should be set to match the address configured in `./src/app-pytorch/pyproject.toml` (see the example snippet after item 2 below). This is not set dynamically. In addition, since OpenFL handles cross-network communication, `superlink_host` is set to localhost by default. For the `flwr run` command, the user should ensure that the `federation_name` and `flwr_app_name` are consistent with what is defined in `./src/<flwr_app_name>` (if different than `app-pytorch`) and `./src/<flwr_app_name>/pyproject.toml`. The Flower directory `flwr_dir` is set to save the FAB in `save/.flwr`. Should a user configure this, the save directory must be located inside the workspace. Additionally, the `ConnectorFlower` has a setting `automatic_shutdown`, which defaults to `True`. When set to `True`, the task runner will shut down the SuperNode at the completion of an experiment; otherwise, it will run continuously. ```yaml connector: defaults: plan/defaults/connector.yaml template: openfl.component.ConnectorFlower settings: - automatic_shutdown: True - superlink_params: - insecure: True - serverappio-api-address: 127.0.0.1:9091 - fleet-api-address: 127.0.0.1:9092 - exec-api-address: 127.0.0.1:9093 - flwr_run_params: - flwr_app_name: "app-pytorch" - federation_name: "local-poc" + automatic_shutdown: true + insecure: true + exec_api_port: 9093 + fleet_api_port: 57085 + serverappio_api_port: 58873 + federation_name: local-poc + flwr_app_name: app-pytorch + flwr_dir: save/.flwr + superlink_host: 127.0.0.1 ``` -2. `FlowerTaskRunner` which will execute the `start_client_adapter` task. This task starts the Flower SuperNode and makes a connection to the OpenFL client. +2. `FlowerTaskRunner` which will execute the `start_client_adapter` task. This task starts the Flower SuperNode and makes a connection to the OpenFL client. In addition, you will notice there are settings for the `flwr_app_name`, `flwr_dir`, and `sgx_enabled`. `flwr_app_name` and `flwr_dir` are for prebuilding and installing the Flower app and should follow the same convention as the `Connector` settings. `sgx_enabled` enables secure execution of the Flower `ClientApp` within an Intel® SGX enclave. When set to `True`, the task runner will launch the client app in an isolated process suitable for enclave execution and handle additional setup required for SGX compatibility (see [Running in Intel® SGX Enclave](#running-in-intel®-sgx-enclave) for details). ```yaml task_runner: defaults: plan/defaults/task_runner.yaml - template: openfl.federated.task.runner_flower.FlowerTaskRunner + template: openfl.federated.task.FlowerTaskRunner + settings : + flwr_app_name : app-pytorch + flwr_dir : save/.flwr + sgx_enabled: False ```
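For illustration, the corresponding federation entry in `./src/app-pytorch/pyproject.toml` might look like the following sketch (the keys follow Flower's `pyproject.toml` federation schema; the name and address values are assumptions chosen to agree with the `federation_name` and `exec_api_port` defaults shown above):

```toml
# Hypothetical entry: `address` must point at the SuperLink's exec API,
# i.e. superlink_host:exec_api_port from the connector settings.
[tool.flwr.federations]
default = "local-poc"

[tool.flwr.federations.local-poc]
address = "127.0.0.1:9093"
insecure = true
```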
3. `FlowerDataLoader` with similar high-level functionality to other dataloaders. -4. `Task` - we introduce a `tasks_connector.yaml` that will allow the collaborator to connect to Flower framework via the interop server. It also handles the task runner's `start_client_adapter` method, which actually starts the Flower component and interop server. By setting `local_server_port` to 0, the port is dynamically allocated. This is mainly for local experiments to avoid overlapping the ports. +4. `Task` - we introduce a `tasks_connector.yaml` that will allow the collaborator to connect to the Flower framework via the interop server. It also handles the task runner's `start_client_adapter` method, which actually starts the Flower component and interop server. In the `settings`, the `interop_server` points to the `FlowerInteropServer` class that will establish the connection between the OpenFL client and the Flower `SuperNode`. Like the `SuperLink`, the host is set to localhost because OpenFL handles cross-network communication. The `interop_server_port` and `clientappio_api_port` are automatically allocated by OpenFL. Setting `local_simulation` to `True` will further offset the ports based on the collaborator names in order to avoid overlapping ports. This is not an issue when collaborators are remote. ```yaml tasks: prepare_for_interop: function: start_client_adapter - kwargs: - interop_server_port: 0 + kwargs: {} settings: - interop_server: src.grpc.connector.flower.interop_server + clientappio_api_port: 59731 + interop_server: openfl.transport.grpc.interop.FlowerInteropServer + interop_server_host: 127.0.0.1 + interop_server_port: 51807 + local_simulation: true + ``` 5.`Collaborator` has an additional setting `interop_mode` which will invoke a callback to prepare the interop server that'll eventually be started by the Task Runner diff --git a/openfl-workspace/flower-app-pytorch/plan/plan.yaml b/openfl-workspace/flower-app-pytorch/plan/plan.yaml index f8d9984e66..a0bf1debf2 100644 --- a/openfl-workspace/flower-app-pytorch/plan/plan.yaml +++ b/openfl-workspace/flower-app-pytorch/plan/plan.yaml @@ -10,18 +10,13 @@ aggregator : write_logs : false connector : - defaults : plan/defaults/connector.yaml - template : src.connector_flower.ConnectorFlower + defaults : plan/defaults/connector_flower.yaml + template : openfl.component.ConnectorFlower settings : - automatic_shutdown : True - superlink_params : - insecure : True - serverappio-api-address : 127.0.0.1:9091 - fleet-api-address : 127.0.0.1:9092 - exec-api-address : 127.0.0.1:9093 - flwr_run_params : - flwr_app_name : "app-pytorch" - federation_name : "local-poc" + exec_api_port : 9093 + flwr_app_name : app-pytorch + federation_name : local-poc + flwr_dir : save/.flwr collaborator : defaults : plan/defaults/collaborator.yaml @@ -31,15 +26,14 @@ collaborator : data_loader : defaults : plan/defaults/data_loader.yaml - template : src.loader.FlowerDataLoader - settings : - collaborator_count : 2 + template : openfl.federated.data.FlowerDataLoader task_runner : defaults : plan/defaults/task_runner.yaml - template : src.runner.FlowerTaskRunner + template : openfl.federated.task.FlowerTaskRunner settings : - flwr_app_name: app-pytorch + flwr_app_name : app-pytorch + flwr_dir : save/.flwr
sgx_enabled: False network : @@ -56,9 +50,7 @@ assigner : - prepare_for_interop tasks : - defaults : plan/defaults/tasks_connector.yaml - settings : - interop_server: src.grpc.connector.flower.interop_server + defaults : plan/defaults/tasks_flower.yaml compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml index 32a51b068d..64061c7542 100644 --- a/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" description = "" license = "Apache-2.0" dependencies = [ - "flwr-nightly", + "flwr-nightly==1.19.0.dev20250513", "flwr-datasets[vision]>=0.5.0", "torch==2.5.1", "torchvision==0.20.1", diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/__init__.py b/openfl-workspace/flower-app-pytorch/src/grpc/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py b/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py deleted file mode 100644 index 035174e6f2..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from src.grpc.connector.utils import get_interop_server diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/__init__.py b/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py b/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py deleted file mode 100644 index 0202346ea9..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib - -def get_interop_server(framework: str = 'Flower') -> object: - if framework == 'Flower': - try: - module = importlib.import_module('src.grpc.connector.flower.interop_server') - return module.FlowerInteropServer - except ImportError: - print("Flower is not installed.") - return None diff --git a/openfl-workspace/flower-app-pytorch/src/util.py b/openfl-workspace/flower-app-pytorch/src/util.py deleted file mode 100644 index 750eff8f2a..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/util.py +++ /dev/null @@ -1,13 +0,0 @@ -import re - -def is_safe_path(path): - """ - Validate the path to ensure it contains only allowed characters. - - Args: - path (str): The path to validate. - - Returns: - bool: True if the path is safe, False otherwise. 
- """ - return re.match(r'^[\w\-/\.]+$', path) is not None diff --git a/openfl-workspace/workspace/plan/defaults/connector.yaml b/openfl-workspace/workspace/plan/defaults/connector.yaml deleted file mode 100644 index 2b6645d22b..0000000000 --- a/openfl-workspace/workspace/plan/defaults/connector.yaml +++ /dev/null @@ -1 +0,0 @@ -template : openfl.component.Connector \ No newline at end of file diff --git a/openfl-workspace/workspace/plan/defaults/connector_flower.yaml b/openfl-workspace/workspace/plan/defaults/connector_flower.yaml new file mode 100644 index 0000000000..66a7e5e914 --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/connector_flower.yaml @@ -0,0 +1,8 @@ +template : openfl.component.ConnectorFlower +settings : + automatic_shutdown : True + insecure : True + superlink_host : 127.0.0.1 + serverappio_api_port : auto + fleet_api_port : auto + exec_api_port : auto \ No newline at end of file diff --git a/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml b/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml deleted file mode 100644 index 71b4db5bda..0000000000 --- a/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml +++ /dev/null @@ -1,4 +0,0 @@ -prepare_for_interop: - function : start_client_adapter - kwargs : - interop_server_port : 0 # interop server port, 0 to dynamically allocate diff --git a/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml b/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml new file mode 100644 index 0000000000..6f0da7841e --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml @@ -0,0 +1,11 @@ +prepare_for_interop: + function : start_client_adapter + kwargs : + {} + +settings: + interop_server : openfl.transport.grpc.interop.FlowerInteropServer + interop_server_host : 127.0.0.1 + interop_server_port : auto + clientappio_api_port : auto + local_simulation : True diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 5b0a22c487..d9e91981ae 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -3,6 +3,8 @@ """OpenFL Component Module.""" +from importlib import util + from openfl.component.aggregator.aggregator import Aggregator from openfl.component.aggregator.straggler_handling import ( CutoffTimePolicy, @@ -14,3 +16,6 @@ from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner from openfl.component.collaborator.collaborator import Collaborator + +if util.find_spec("flwr") is not None: + from openfl.component.connector.connector_flower import ConnectorFlower diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 178bd8076a..dc0c898aaf 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -7,6 +7,7 @@ import importlib import logging from enum import Enum +from os.path import splitext from time import sleep from typing import List, Optional @@ -607,8 +608,10 @@ def prepare_interop_server(self): """ # Initialize the interop server - framework = self.task_config["settings"]["interop_server"] - module = importlib.import_module(framework) + interop_server_template = self.task_config["settings"]["interop_server"] + interop_server_class = splitext(interop_server_template)[1].strip(".") + interop_server_module_path = splitext(interop_server_template)[0] + interop_server_module = 
importlib.import_module(interop_server_module_path) def receive_message_from_interop(message): """Receive message from interop server.""" @@ -616,5 +619,11 @@ def receive_message_from_interop(message): response = self.client.send_message_to_server(message, self.collaborator_name) return response - interop_server = module.FlowerInteropServer(receive_message_from_interop) + interop_server = getattr(interop_server_module, interop_server_class)( + receive_message_from_interop + ) + # Pass all keys in self.task_config['settings'] through to prepare_for_interop kwargs + self.task_config["prepare_for_interop"]["kwargs"].update( + self.task_config.get("settings", {}) + ) self.task_config["prepare_for_interop"]["kwargs"]["interop_server"] = interop_server diff --git a/openfl/component/connector/__init__.py b/openfl/component/connector/__init__.py new file mode 100644 index 0000000000..482e80e33c --- /dev/null +++ b/openfl/component/connector/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""OpenFL Connector Module.""" + +from importlib import util + +if util.find_spec("flwr") is not None: + from openfl.component.connector.connector_flower import ConnectorFlower diff --git a/openfl-workspace/flower-app-pytorch/src/connector_flower.py b/openfl/component/connector/connector_flower.py similarity index 53% rename from openfl-workspace/flower-app-pytorch/src/connector_flower.py rename to openfl/component/connector/connector_flower.py index afe7057201..24a8d35cb2 100644 --- a/openfl-workspace/flower-app-pytorch/src/connector_flower.py +++ b/openfl/component/connector/connector_flower.py @@ -1,22 +1,19 @@ -from logging import getLogger -logger = getLogger(__name__) +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 -import psutil +import os +import signal import subprocess import sys -import signal +from logging import getLogger -from src.grpc.connector.flower.interop_client import FlowerInteropClient -from src.util import is_safe_path +import psutil -import os +from openfl.transport.grpc.interop import FlowerInteropClient +from openfl.utilities.path_check import is_directory_traversal -flwr_home = os.path.join(os.getcwd(), "save/.flwr") -if not is_safe_path(flwr_home): - raise ValueError("Invalid path for FLWR_HOME") +logger = getLogger(__name__) -os.environ["FLWR_HOME"] = flwr_home -os.makedirs(os.environ["FLWR_HOME"], exist_ok=True) class ConnectorFlower: """ @@ -24,30 +21,66 @@ class ConnectorFlower: This class is responsible for constructing and managing the execution of Flower server commands. """ - def __init__(self, - superlink_params: dict, - flwr_run_params: dict = None, - automatic_shutdown: bool = True, - **kwargs): + def __init__( + self, + superlink_host, + fleet_api_port, + exec_api_port, + serverappio_api_port, + insecure=True, + flwr_app_name=None, + federation_name=None, + automatic_shutdown=True, + flwr_dir=None, + **kwargs, + ): """ Initialize the ConnectorFlower instance by setting up the necessary server commands. Args: - superlink_params (dict): Configuration settings for the Flower server. - flwr_run_params (dict, optional): Parameters for running the Flower application. - automatic_shutdown (bool, optional): Flag to enable automatic shutdown of the server. Defaults to True. + superlink_host (str): Host address for the Flower SuperLink. + fleet_api_port (int): Port for the fleet API. + exec_api_port (int): Port for the exec API. + serverappio_api_port (int): Port for the serverappio API. 
+ insecure (bool): Whether to use insecure connections. Defaults to True. + flwr_app_name (str, optional): Name of the Flower application to run. Defaults to None. + federation_name (str, optional): Name of the federation. Defaults to None. + automatic_shutdown (bool, optional): Whether to enable automatic shutdown. + Defaults to True. + flwr_dir (str, optional): Directory for the Flower app within the OpenFL workspace. + The plan.yaml configuration defaults to `save/.flwr`. **kwargs: Additional keyword arguments. """ super().__init__() self._process = None + self.flwr_dir = flwr_dir + if is_directory_traversal(self.flwr_dir): + logger.error("Flower app directory path is out of the OpenFL workspace scope.") + sys.exit(1) + else: + os.makedirs(self.flwr_dir, exist_ok=True) + os.environ["FLWR_HOME"] = self.flwr_dir + self.automatic_shutdown = automatic_shutdown self.signal_shutdown_sent = False - self.superlink_params = superlink_params + self.superlink_params = { + "insecure": insecure, + "exec_api_port": exec_api_port, + "fleet_api_port": fleet_api_port, + "serverappio_api_port": serverappio_api_port, + } + self.superlink_host = superlink_host self.flwr_superlink_command = self._build_flwr_superlink_command() - self.flwr_run_params = flwr_run_params + if flwr_app_name is None or federation_name is None: + self.flwr_run_params = None + else: + self.flwr_run_params = { + "flwr_app_name": flwr_app_name, + "federation_name": federation_name, + } self.flwr_run_command = self._build_flwr_run_command() if self.flwr_run_params else None self.interop_client = None @@ -55,12 +88,14 @@ def __init__(self, def get_interop_client(self): """ - Create and return a LocalGRPCClient instance using the superlink parameters. + Create and return a FlowerInteropClient instance using the superlink parameters. Returns: - LocalGRPCClient: An instance configured with the connector address and server rounds. + FlowerInteropClient: An instance configured with the connector address + and the automatic shutdown flag.
""" - connector_address = self.superlink_params.get("fleet-api-address", "0.0.0.0:9092") + connector_port = self.superlink_params.get("fleet_api_port") + connector_address = f"{self.superlink_host}:{connector_port}" self.interop_client = FlowerInteropClient(connector_address, self.automatic_shutdown) return self.interop_client @@ -74,20 +109,20 @@ def _build_flwr_superlink_command(self) -> list[str]: command = ["flower-superlink", "--fleet-api-type", "grpc-adapter"] - if "insecure" in self.superlink_params and self.superlink_params["insecure"]: + if self.superlink_params.get("insecure"): command += ["--insecure"] - if "serverappio-api-address" in self.superlink_params: - command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])] - # flwr default: 0.0.0.0:9091 + serverappio_api_port = self.superlink_params.get("serverappio_api_port") + serverappio_api_address = f"{self.superlink_host}:{serverappio_api_port}" + command += ["--serverappio-api-address", serverappio_api_address] - if "fleet-api-address" in self.superlink_params: - command += ["--fleet-api-address", str(self.superlink_params["fleet-api-address"])] - # flwr default: 0.0.0.0:9092 + fleet_api_port = self.superlink_params.get("fleet_api_port") + fleet_api_address = f"{self.superlink_host}:{fleet_api_port}" + command += ["--fleet-api-address", fleet_api_address] - if "exec-api-address" in self.superlink_params: - command += ["--exec-api-address", str(self.superlink_params["exec-api-address"])] - # flwr default: 0.0.0.0:9093 + exec_api_port = self.superlink_params.get("exec_api_port") + exec_api_address = f"{self.superlink_host}:{exec_api_port}" + command += ["--exec-api-address", exec_api_address] if self.automatic_shutdown: command += ["--isolation", "process"] @@ -105,11 +140,12 @@ def _build_flwr_serverapp_command(self) -> list[str]: """ command = ["flwr-serverapp", "--run-once"] - if "insecure" in self.superlink_params and self.superlink_params["insecure"]: + if self.superlink_params["insecure"]: command += ["--insecure"] - if "serverappio-api-address" in self.superlink_params: - command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])] + serverappio_api_port = self.superlink_params["serverappio_api_port"] + serverappio_api_address = f"{self.superlink_host}:{serverappio_api_port}" + command += ["--serverappio-api-address", serverappio_api_address] return command @@ -120,7 +156,7 @@ def is_flwr_serverapp_running(self): Returns: bool: True if the ServerApp is running, False otherwise. """ - if not hasattr(self, 'flwr_serverapp_subprocess'): + if not hasattr(self, "flwr_serverapp_subprocess"): logger.debug("[OpenFL Connector] ServerApp was never started.") return False @@ -130,13 +166,19 @@ def is_flwr_serverapp_running(self): if not self.signal_shutdown_sent: self.signal_shutdown_sent = True - logger.info("[OpenFL Connector] Experiment has ended. Sending signal to shut down Flower components.") + logger.info( + "[OpenFL Connector] Experiment has ended. Sending signal " + "to shut down Flower components." + ) return False def _stop_flwr_serverapp(self): """Terminate the `flwr_serverapp` subprocess if it is still active.""" - if hasattr(self, 'flwr_serverapp_subprocess') and self.flwr_serverapp_subprocess.poll() is None: + if ( + hasattr(self, "flwr_serverapp_subprocess") + and self.flwr_serverapp_subprocess.poll() is None + ): logger.debug("[OpenFL Connector] ServerApp still running. 
Stopping...") self.flwr_serverapp_subprocess.terminate() try: @@ -162,20 +204,35 @@ def _build_flwr_run_command(self) -> list[str]: return command def start(self): - """Launch the `flower-superlink` and `flwr run` subprocesses using the constructed commands.""" + """ + Launch the `flower-superlink` and `flwr run` subprocesses + using the constructed commands. + """ if self._process is None: - logger.info(f"[OpenFL Connector] Starting server process: {' '.join(self.flwr_superlink_command)}") + logger.info( + f"[OpenFL Connector] Starting server process: " + f"{' '.join(self.flwr_superlink_command)}" + ) self._process = subprocess.Popen(self.flwr_superlink_command) logger.info(f"[OpenFL Connector] Server process started with PID: {self._process.pid}") else: logger.info("[OpenFL Connector] Server process is already running.") - if hasattr(self, 'flwr_run_command') and self.flwr_run_command: - logger.info(f"[OpenFL Connector] Starting `flwr run` subprocess: {' '.join(self.flwr_run_command)}") + if hasattr(self, "flwr_run_command") and self.flwr_run_command: + logger.info( + f"[OpenFL Connector] Starting `flwr run` " + f"subprocess: {' '.join(self.flwr_run_command)}" + ) subprocess.run(self.flwr_run_command) - if hasattr(self, 'flwr_serverapp_command') and self.flwr_serverapp_command: - self.interop_client.set_is_flwr_serverapp_running_callback(self.is_flwr_serverapp_running) + if hasattr(self, "flwr_serverapp_command") and self.flwr_serverapp_command: + logger.info( + f"[OpenFL Connector] Starting server app subprocess: " + f"{' '.join(self.flwr_serverapp_command)}" + ) + self.interop_client.set_is_flwr_serverapp_running_callback( + self.is_flwr_serverapp_running + ) self.flwr_serverapp_subprocess = subprocess.Popen(self.flwr_serverapp_command) def stop(self): @@ -183,11 +240,18 @@ def stop(self): self._stop_flwr_serverapp() if self._process: try: - logger.info(f"[OpenFL Connector] Stopping server process with PID: {self._process.pid}...") + logger.info( + f"[OpenFL Connector] Stopping server process with PID: {self._process.pid}..." + ) main_process = psutil.Process(self._process.pid) sub_processes = main_process.children(recursive=True) for sub_process in sub_processes: - logger.info(f"[OpenFL Connector] Stopping server subprocess with PID: {sub_process.pid}...") + logger.info( + ( + f"[OpenFL Connector] Stopping server subprocess " + f"with PID: {sub_process.pid}..." 
+ ) + ) sub_process.terminate() _, still_alive = psutil.wait_procs(sub_processes, timeout=1) for p in still_alive: diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py index a8b443c059..e17857d330 100644 --- a/openfl/federated/__init__.py +++ b/openfl/federated/__init__.py @@ -26,6 +26,9 @@ if util.find_spec("xgboost") is not None: from openfl.federated.data import XGBoostDataLoader from openfl.federated.task import XGBoostTaskRunner +if util.find_spec("flwr") is not None: + from openfl.federated.data import FlowerDataLoader + from openfl.federated.task import FlowerTaskRunner __all__ = [ "Plan", diff --git a/openfl/federated/data/__init__.py b/openfl/federated/data/__init__.py index 53e56a7f7d..29667f7b23 100644 --- a/openfl/federated/data/__init__.py +++ b/openfl/federated/data/__init__.py @@ -16,3 +16,6 @@ if util.find_spec("xgboost") is not None: from openfl.federated.data.loader_xgb import XGBoostDataLoader # NOQA + +if util.find_spec("flwr") is not None: + from openfl.federated.data.loader_flower import FlowerDataLoader # NOQA diff --git a/openfl-workspace/flower-app-pytorch/src/loader.py b/openfl/federated/data/loader_flower.py similarity index 96% rename from openfl-workspace/flower-app-pytorch/src/loader.py rename to openfl/federated/data/loader_flower.py index 0b63f60af0..1a4305b198 100644 --- a/openfl-workspace/flower-app-pytorch/src/loader.py +++ b/openfl/federated/data/loader_flower.py @@ -1,11 +1,12 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """FlowerDataLoader module.""" -from openfl.federated.data.loader import DataLoader import os +from openfl.federated.data.loader import DataLoader + class FlowerDataLoader(DataLoader): """Flower Dataloader @@ -25,7 +26,7 @@ def __init__(self, data_path, **kwargs): Raises: FileNotFoundError: If the specified data path does not exist. 
- """ + """ super().__init__(**kwargs) if not os.path.exists(data_path): raise FileNotFoundError(f"The specified data path does not exist: {data_path}") diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 8cdb5a13e5..4d0459cbbc 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -21,7 +21,7 @@ AggregatorRESTClient, AggregatorRESTServer, ) -from openfl.utilities.utils import getfqdn_env +from openfl.utilities.utils import generate_port, getfqdn_env SETTINGS = "settings" TEMPLATE = "template" @@ -317,9 +317,18 @@ def resolve(self): self.config["network"][SETTINGS]["agg_addr"] = getfqdn_env() if self.config["network"][SETTINGS]["agg_port"] == AUTO: - self.config["network"][SETTINGS]["agg_port"] = ( - int(self.hash[:8], 16) % (60999 - 49152) + 49152 - ) + self.config["network"][SETTINGS]["agg_port"] = generate_port(self.hash) + + if "connector" in self.config: + # automatically generate ports for Flower interoperability components + # if they are set to AUTO + for key, value in self.config["connector"][SETTINGS].items(): + if value == AUTO: + self.config["connector"][SETTINGS][key] = generate_port(self.hash) + + for key, value in self.config["tasks"][SETTINGS].items(): + if value == AUTO: + self.config["tasks"][SETTINGS][key] = generate_port(self.hash) def get_assigner(self): """Get the plan task assigner.""" diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index 7d1d7dfaeb..1763b3c54d 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -14,3 +14,5 @@ from openfl.federated.task.runner_pt import PyTorchTaskRunner # NOQA if util.find_spec("xgboost") is not None: from openfl.federated.task.runner_xgb import XGBoostTaskRunner # NOQA +if util.find_spec("flwr") is not None: + from openfl.federated.task.runner_flower import FlowerTaskRunner # NOQA diff --git a/openfl-workspace/flower-app-pytorch/src/runner.py b/openfl/federated/task/runner_flower.py similarity index 59% rename from openfl-workspace/flower-app-pytorch/src/runner.py rename to openfl/federated/task/runner_flower.py index 9fdbd8d619..dc4f4a598c 100644 --- a/openfl-workspace/flower-app-pytorch/src/runner.py +++ b/openfl/federated/task/runner_flower.py @@ -1,19 +1,23 @@ -from openfl.federated.task.runner import TaskRunner +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import logging +import os +import socket import subprocess -from logging import getLogger +import sys import time -import os -import numpy as np from pathlib import Path -import socket -from src.util import is_safe_path -flwr_home = os.path.join(os.getcwd(), "save/.flwr") -if not is_safe_path(flwr_home): - raise ValueError("Invalid path for FLWR_HOME") +import numpy as np + +from openfl.federated.task.runner import TaskRunner +from openfl.utilities.path_check import is_directory_traversal +from openfl.utilities.utils import generate_port + +logger = logging.getLogger(__name__) -os.environ["FLWR_HOME"] = flwr_home -os.makedirs(os.environ["FLWR_HOME"], exist_ok=True) class FlowerTaskRunner(TaskRunner): """ @@ -24,6 +28,7 @@ class FlowerTaskRunner(TaskRunner): in a subprocess. It provides options for both manual and automatic shutdown based on subprocess activity. """ + def __init__(self, **kwargs): """ Initialize the FlowerTaskRunner. 
@@ -33,29 +38,28 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) + self.flwr_dir = kwargs.get("flwr_dir") + if is_directory_traversal(self.flwr_dir): + logger.error("Flower app directory path is out of the OpenFL workspace scope.") + sys.exit(1) + else: + os.makedirs(self.flwr_dir, exist_ok=True) + os.environ["FLWR_HOME"] = self.flwr_dir + if self.data_loader is None: - flwr_app_name = kwargs.get('flwr_app_name') + flwr_app_name = kwargs.get("flwr_app_name") install_flower_FAB(flwr_app_name) return - self.sgx_enabled = kwargs.get('sgx_enabled') + self.sgx_enabled = kwargs.get("sgx_enabled") self.model = None - self.logger = getLogger(__name__) self.data_path = self.data_loader.get_node_configs() - self.client_port = kwargs.get('client_port') - if self.client_port is None: - self.client_port = get_dynamic_port() - self.shutdown_requested = False # Flag to signal shutdown - def start_client_adapter(self, - col_name=None, - round_num=None, - input_tensor_dict=None, - **kwargs): + def start_client_adapter(self, col_name=None, round_num=None, input_tensor_dict=None, **kwargs): """ Start the FlowerInteropServer and the Flower SuperNode. @@ -66,27 +70,43 @@ def start_client_adapter(self, **kwargs: Additional parameters for configuration. includes: interop_server (object): The FlowerInteropServer instance. + interop_server_host (str): The address of the interop server. + clientappio_api_port (int): The port for the clientappio API. + local_simulation (bool): Flag for local simulation to dynamically adjust ports. interop_server_port (int): The port for the interop server. """ def message_callback(): self.shutdown_requested = True - interop_server = kwargs.get('interop_server') - interop_server_port = kwargs.get('interop_server_port') - interop_server.set_end_experiment_callback(message_callback) - interop_server.start_server(interop_server_port) + interop_server = kwargs.get("interop_server") + interop_server_host = kwargs.get("interop_server_host") + interop_server_port = kwargs.get("interop_server_port") + clientappio_api_port = kwargs.get("clientappio_api_port") + + if kwargs.get("local_simulation"): + # Dynamically adjust ports for local simulation + logger.info(f"Adjusting ports for local simulation: {col_name}") + + interop_server_port = get_dynamic_port(interop_server_port, col_name) + clientappio_api_port = get_dynamic_port(clientappio_api_port, col_name) - # interop server sets port dynamically - interop_server_port = interop_server.get_port() + logger.info(f"Adjusted interop_server_port: {interop_server_port}") + logger.info(f"Adjusted clientappio_api_port: {clientappio_api_port}") + + interop_server.set_end_experiment_callback(message_callback) + interop_server.start_server(interop_server_host, interop_server_port) command = [ "flower-supernode", "--insecure", "--grpc-adapter", - "--superlink", f"127.0.0.1:{interop_server_port}", - "--clientappio-api-address", f"127.0.0.1:{self.client_port}", - "--node-config", f"data-path='{self.data_path}'" + "--superlink", + f"{interop_server_host}:{interop_server_port}", + "--clientappio-api-address", + f"{interop_server_host}:{clientappio_api_port}", + "--node-config", + f"data-path='{self.data_path}'", ] if self.sgx_enabled: @@ -94,34 +114,35 @@ def message_callback(): flwr_clientapp_command = [ "flwr-clientapp", "--insecure", - "--clientappio-api-address", f"127.0.0.1:{self.client_port}", + "--clientappio-api-address", + f"{interop_server_host}:{clientappio_api_port}", ] - self.logger.info("Starting Flower SuperNode process...") + 
logger.info("Starting Flower SuperNode process...") supernode_process = subprocess.Popen(command, shell=False) interop_server.handle_signals(supernode_process) if self.sgx_enabled: # Check if port is open before starting the client app - while not is_port_open('127.0.0.1', interop_server_port): + while not is_port_open(interop_server_host, interop_server_port): time.sleep(0.5) - time.sleep(1) # Add a small delay after confirming the port is open + time.sleep(1) # Add a small delay after confirming the port is open - self.logger.info("Starting Flower ClientApp process...") + logger.info("Starting Flower ClientApp process...") flwr_clientapp_process = subprocess.Popen(flwr_clientapp_command, shell=False) interop_server.handle_signals(flwr_clientapp_process) - self.logger.info("Press CTRL+C to stop the server and SuperNode process.") + logger.info("Press CTRL+C to stop the server and SuperNode process.") while not interop_server.termination_event.is_set(): if self.shutdown_requested: if self.sgx_enabled: - self.logger.info("Terminating Flower ClientApp process...") + logger.info("Terminating Flower ClientApp process...") interop_server.terminate_supernode_process(flwr_clientapp_process) flwr_clientapp_process.wait() - self.logger.info("Shutting down the server and SuperNode process...") + logger.info("Shutting down the server and SuperNode process...") interop_server.terminate_supernode_process(supernode_process) interop_server.stop_server() time.sleep(0.1) @@ -133,8 +154,6 @@ def message_callback(): return global_output_tensor_dict, local_output_tensor_dict - - def set_tensor_dict(self, tensor_dict, with_opt_vars=False): """ Set the tensor dictionary for the task runner. @@ -169,7 +188,7 @@ def save_native(self, filepath, **kwargs): if isinstance(filepath, Path): filepath = str(filepath) - assert filepath.endswith('.npz'), "Currently, only '.npz' file type is supported." + assert filepath.endswith(".npz"), "Currently, only '.npz' file type is supported." # Save the tensor dictionary to a .npz file np.savez(filepath, **self.tensor_dict) @@ -182,54 +201,43 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): """Get tensor keys for functions. Return empty dict.""" return {} + def install_flower_FAB(flwr_app_name): """ - Build and install the patch for the Flower application. + Build and install Flower application. Args: - flwr_app_name (str): The name of the Flower application to patch. + flwr_app_name (str): The name of the Flower application. """ - flwr_dir = os.environ["FLWR_HOME"] - - # Change the current working directory to the Flower directory - os.chdir(flwr_dir) - # Run the build command - build_command = [ - "flwr", - "build", - "--app", - os.path.join("..", "..", "src", flwr_app_name) - ] + build_command = ["flwr", "build", "--app", os.path.join("src", flwr_app_name)] subprocess.check_call(build_command) # List .fab files after running the build command - fab_files = list(Path(flwr_dir).glob("*.fab")) + fab_files = list(Path.cwd().glob("*.fab")) # Determine the newest .fab file newest_fab_file = max(fab_files, key=os.path.getmtime) # Run the install command using the newest .fab file - subprocess.check_call([ - "flwr", - "install", - str(newest_fab_file) - ]) + install_command = ["flwr", "install", str(newest_fab_file)] + subprocess.check_call(install_command) + os.remove(newest_fab_file) + -def get_dynamic_port(): +def get_dynamic_port(base_port, collaborator_name): """ - Get a dynamically assigned port number. 
+ Get a dynamically assigned port number based on collaborator name and base port. + This is only necessary for local simulation in order to avoid port conflicts. Returns: - int: An available port number assigned by the operating system. + int: The dynamically assigned port number. """ - # Create a socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # Bind to port 0 to let the OS assign an available port - s.bind(('127.0.0.1', 0)) - # Get the assigned port number - port = s.getsockname()[1] - return port + combined_string = f"{base_port}--{collaborator_name}" + hash_object = hashlib.md5(combined_string.encode()) + hash_value = hash_object.hexdigest() + return generate_port(hash_value) + def is_port_open(host, port): """Check if a port is open on the given host.""" diff --git a/openfl/transport/grpc/interop/__init__.py b/openfl/transport/grpc/interop/__init__.py new file mode 100644 index 0000000000..d481116591 --- /dev/null +++ b/openfl/transport/grpc/interop/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from importlib import util + +if util.find_spec("flwr") is not None: + from openfl.transport.grpc.interop.flower.interop_client import FlowerInteropClient + from openfl.transport.grpc.interop.flower.interop_server import FlowerInteropServer diff --git a/openfl/transport/grpc/interop/flower/__init__.py b/openfl/transport/grpc/interop/flower/__init__.py new file mode 100644 index 0000000000..d481116591 --- /dev/null +++ b/openfl/transport/grpc/interop/flower/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from importlib import util + +if util.find_spec("flwr") is not None: + from openfl.transport.grpc.interop.flower.interop_client import FlowerInteropClient + from openfl.transport.grpc.interop.flower.interop_server import FlowerInteropServer diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py b/openfl/transport/grpc/interop/flower/interop_client.py similarity index 82% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py rename to openfl/transport/grpc/interop/flower/interop_client.py index 7159baf9e7..1f21ddbe38 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py +++ b/openfl/transport/grpc/interop/flower/interop_client.py @@ -1,7 +1,14 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import grpc from flwr.proto import grpcadapter_pb2_grpc -from src.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message -from logging import getLogger + +from openfl.transport.grpc.interop.flower.message_conversion import ( + flower_to_openfl_message, + openfl_to_flower_message, +) + class FlowerInteropClient: """ @@ -9,6 +16,7 @@ class FlowerInteropClient: and the OpenFL Server. It converts messages between OpenFL and Flower formats and handles the send-receive communication with the Flower SuperNode using gRPC. """ + def __init__(self, superlink_address, automatic_shutdown=False): """ Initialize. 
@@ -23,8 +31,6 @@ def __init__(self, superlink_address, automatic_shutdown=False): self.end_experiment = False self.is_flwr_serverapp_running_callback = None - self.logger = getLogger(__name__) - def set_is_flwr_serverapp_running_callback(self, is_flwr_serverapp_running_callback): self.is_flwr_serverapp_running_callback = is_flwr_serverapp_running_callback @@ -47,8 +53,8 @@ def send_receive(self, openfl_message, header): # then the experiment has completed self.end_experiment = not self.is_flwr_serverapp_running_callback() - openfl_response = flower_to_openfl_message(flower_response, - header=header, - end_experiment=self.end_experiment) + openfl_response = flower_to_openfl_message( + flower_response, header=header, end_experiment=self.end_experiment + ) return openfl_response diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py b/openfl/transport/grpc/interop/flower/interop_server.py similarity index 84% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py rename to openfl/transport/grpc/interop/flower/interop_server.py index 16f104b576..732c5c8341 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py +++ b/openfl/transport/grpc/interop/flower/interop_server.py @@ -1,14 +1,22 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import logging -import threading import queue -import grpc +import signal +import threading +import time from concurrent.futures import ThreadPoolExecutor -from flwr.proto import grpcadapter_pb2_grpc -from src.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message from multiprocessing import cpu_count -import signal + +import grpc import psutil -import time +from flwr.proto import grpcadapter_pb2_grpc + +from openfl.transport.grpc.interop.flower.message_conversion import ( + flower_to_openfl_message, + openfl_to_flower_message, +) logger = logging.getLogger(__name__) @@ -26,7 +34,8 @@ def __init__(self, send_message_to_client): Initialize. Args: - send_message_to_client (Callable): A callable function to send messages to the OpenFL client. + send_message_to_client (Callable): A callable function to send messages + to the OpenFL client. """ self.send_message_to_client = send_message_to_client self.end_experiment_callback = None @@ -41,18 +50,14 @@ def __init__(self, send_message_to_client): def set_end_experiment_callback(self, callback): self.end_experiment_callback = callback - def start_server(self, local_server_port): + def start_server(self, interop_server_host, interop_server_port): """Starts the gRPC server.""" self.server = grpc.server(ThreadPoolExecutor(max_workers=cpu_count())) grpcadapter_pb2_grpc.add_GrpcAdapterServicer_to_server(self, self.server) - self.port = self.server.add_insecure_port(f'[::]:{local_server_port}') + self.port = self.server.add_insecure_port(f"{interop_server_host}:{interop_server_port}") self.server.start() logger.info(f"OpenFL local gRPC server started, listening on port {self.port}.") - def get_port(self): - # Return the port that was assigned - return self.port - def stop_server(self): """Stops the gRPC server.""" if self.server: @@ -62,7 +67,10 @@ def stop_server(self): self.termination_event.set() def SendReceive(self, request, context): - """ Handles incoming gRPC requests by putting them into the request queue and waiting for the response. 
+ """ + Handles incoming gRPC requests by putting them into the request + queue and waiting for the response. + Args: request: The incoming gRPC request. context: The gRPC context. @@ -87,8 +95,8 @@ def process_queue(self): openfl_response = self.send_message_to_client(openfl_request) # Check to end experiment - if hasattr(openfl_response, 'metadata'): - if openfl_response.metadata['end_experiment'] == 'True': + if hasattr(openfl_response, "metadata"): + if openfl_response.metadata["end_experiment"] == "True": self.end_experiment_callback() # Send response to Flower client @@ -98,6 +106,7 @@ def process_queue(self): def handle_signals(self, supernode_process): """Sets up signal handlers for graceful shutdown.""" + def signal_handler(_sig, _frame): self.terminate_supernode_process(supernode_process) self.stop_server() @@ -132,7 +141,10 @@ def terminate_process(self, process, timeout=5): process.terminate() process.wait(timeout=timeout) except psutil.TimeoutExpired: - logger.debug(f"Timeout expired while waiting for process {process.pid} to terminate. Killing the process.") + logger.debug( + f"Timeout expired while waiting for process {process.pid} " + "to terminate. Killing the process." + ) process.kill() except psutil.NoSuchProcess: logger.debug(f"Process {process.pid} does not exist. Skipping.") diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py b/openfl/transport/grpc/interop/flower/message_conversion.py similarity index 94% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py rename to openfl/transport/grpc/interop/flower/message_conversion.py index d900b83cf0..43be8b19de 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py +++ b/openfl/transport/grpc/interop/flower/message_conversion.py @@ -1,9 +1,12 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from flwr.proto import grpcadapter_pb2 + from openfl.protocols import aggregator_pb2 -def flower_to_openfl_message(flower_message, - header=None, - end_experiment=False): + +def flower_to_openfl_message(flower_message, header=None, end_experiment=False): """ Convert a Flower MessageContainer to an OpenFL InteropMessage. @@ -40,6 +43,7 @@ def flower_to_openfl_message(flower_message, openfl_message.metadata.update({"end_experiment": str(end_experiment)}) return openfl_message + def openfl_to_flower_message(openfl_message): """ Convert an OpenFL InteropMessage to a Flower MessageContainer. diff --git a/openfl/utilities/utils.py b/openfl/utilities/utils.py index bab2ccc8c3..4f4e5fc2eb 100644 --- a/openfl/utilities/utils.py +++ b/openfl/utilities/utils.py @@ -263,3 +263,22 @@ def remove_readonly(func, path, _): func(path) return shutil.rmtree(path, ignore_errors=ignore_errors, onerror=remove_readonly) + + +def generate_port(hash, port_range=(49152, 60999)): + """ + Generate a deterministic port number based on a hash and a unique key. + + Args: + hash (str): A string representing the hash of the plan. + port_range (tuple): A tuple containing the minimum and maximum port + numbers (inclusive). The default range is (49152, 60999). + + Returns: + int: A port number within the specified range. 
+ """ + min_port, max_port = port_range + # Use the first 8 characters of the unique hash to ensure deterministic output + hash_segment = hash[:8] + port = int(hash_segment, 16) % (max_port - min_port) + min_port + return port diff --git a/tests/end_to_end/models/collaborator.py b/tests/end_to_end/models/collaborator.py index 0729d59222..ccebde7913 100644 --- a/tests/end_to_end/models/collaborator.py +++ b/tests/end_to_end/models/collaborator.py @@ -246,3 +246,26 @@ def ping_aggregator(self): log.error(f"{error_msg}: {e}") raise e return True + + def calculate_hash(self): + """ + Calculate the hash of the data directory and store in hash.txt file + Returns: + bool: True if successful, else False + """ + try: + log.info(f"Calculating hash for {self.collaborator_name}") + cmd = f"fx collaborator calchash --data_path {self.data_directory_path}" + error_msg = "Failed to calculate hash" + return_code, output, error = fh.run_command( + cmd, + error_msg=error_msg, + container_id=self.container_id, + workspace_path=self.workspace_path, + ) + fh.verify_cmd_output(output, return_code, error, error_msg, f"Calculated hash for {self.collaborator_name}") + + except Exception as e: + log.error(f"{error_msg}: {e}") + raise e + return True diff --git a/tests/end_to_end/models/model_owner.py b/tests/end_to_end/models/model_owner.py index 3d2c8d0f9f..70da74e703 100644 --- a/tests/end_to_end/models/model_owner.py +++ b/tests/end_to_end/models/model_owner.py @@ -164,7 +164,7 @@ def modify_plan(self, param_config, plan_path): data["network"]["settings"]["require_client_auth"] = param_config.require_client_auth data["network"]["settings"]["use_tls"] = param_config.use_tls if param_config.tr_rest_api: - data["task_runner"]["settings"]["transport_protocol"] = "rest" + data["network"]["settings"]["transport_protocol"] = "rest" if param_config.secure_agg: data["aggregator"]["settings"]["secure_aggregation"] = True with open(plan_file, "w+") as write_file: diff --git a/tests/end_to_end/models/s3_bucket.py b/tests/end_to_end/models/s3_bucket.py new file mode 100644 index 0000000000..2a1151565a --- /dev/null +++ b/tests/end_to_end/models/s3_bucket.py @@ -0,0 +1,688 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import time +import signal +import shutil +import atexit +import boto3 +import logging +from botocore.client import Config +from botocore.exceptions import ClientError +import fnmatch +from pathlib import Path + +import tests.end_to_end.utils.constants as constants + +log = logging.getLogger(__name__) + + +class S3Bucket(): + """ + A class to manage S3 bucket operations using boto3. + This class provides methods to create, delete, upload, download, + and list objects in S3 buckets, as well as manage MinIO server. + """ + + def __init__( + self, + endpoint_url=constants.MINIO_URL, + access_key=constants.MINIO_ROOT_USER, + secret_key=constants.MINIO_ROOT_PASSWORD, + region="us-east-1", + ): + """ + Initialize S3Helper with connection details. 
+
+        Args:
+            endpoint_url: The S3 endpoint URL (default: http://localhost:9000 for MinIO)
+            access_key: The access key (if None, uses the MINIO_ROOT_USER env variable)
+            secret_key: The secret key (if None, uses the MINIO_ROOT_PASSWORD env variable)
+            region: The region name (default: us-east-1, required by boto3 but not used by MinIO)
+        """
+        self.endpoint_url = endpoint_url
+        self.access_key = access_key or os.environ.get("MINIO_ROOT_USER", "minioadmin")
+        self.secret_key = secret_key or os.environ.get(
+            "MINIO_ROOT_PASSWORD", "minioadmin"
+        )
+        self.region = region
+
+        # Extract host and port from endpoint_url
+        url_parts = self.endpoint_url.split('://')[-1].split(':')
+        self.minio_host = url_parts[0]
+        self.minio_port = int(url_parts[1]) if len(url_parts) > 1 else 9000
+
+        # Set default URLs
+        self.minio_url = f"{self.minio_host}:{self.minio_port}"
+        self.minio_console_url = f"{self.minio_host}:{self.minio_port + 1}"
+
+        # Initialize S3 client
+        self.client = boto3.client(
+            "s3",
+            endpoint_url=self.endpoint_url,
+            aws_access_key_id=self.access_key,
+            aws_secret_access_key=self.secret_key,
+            config=Config(signature_version="s3v4"),
+            region_name=self.region,
+        )
+
+    def is_minio_server_running(self, host='localhost', port=9000):
+        """
+        Check whether anything is listening on the given host and port
+        (assumed to be a MinIO server).
+
+        Args:
+            host: Host name (default: localhost)
+            port: Port number (default: 9000)
+
+        Returns:
+            list | None: List of PIDs listening on the port, or None if the
+                port is free or the check fails.
+        """
+        try:
+            check_cmd = ['lsof', '-i', f':{port}', '-t']
+            output = subprocess.check_output(check_cmd, universal_newlines=True).strip()
+            if output:
+                pids = [int(pid) for pid in output.split()]
+                log.info(f"Port {port} is in use (lsof check), PID(s): {pids}")
+                return pids
+        except Exception:
+            pass
+        return None
+
+    def start_minio_server(
+        self,
+        data_dir,
+        access_key=None,
+        secret_key=None,
+        address=None,
+        console_address=None,
+        clean_start=True,
+    ):
+        """
+        Start a MinIO server as a subprocess.
+
+        Args:
+            data_dir: Directory to store data
+            access_key: MinIO access key (default: from instance)
+            secret_key: MinIO secret key (default: from instance)
+            address: Address to bind the MinIO server (default: from instance)
+            console_address: Address to bind the MinIO console (default: from instance)
+            clean_start: If True, terminate any existing server and start fresh (default: True)
+
+        Returns:
+            subprocess.Popen: The process object for the MinIO server
+        """
+        # Use instance values if not provided
+        address = address or self.minio_url
+        console_address = console_address or self.minio_console_url
+        access_key = access_key or self.access_key
+        secret_key = secret_key or self.secret_key
+
+        # Parse address to get host and port
+        try:
+            host, port = address.split(':')
+            port = int(port)
+        except ValueError:
+            # Fall back to MinIO's default API endpoint
+            host = 'localhost'
+            port = 9000
+
+        # Check if a MinIO server is already running
+        running = self.is_minio_server_running(host, port)
+        if running:
+            if not clean_start:
+                log.info("MinIO server already running. Skipping startup.")
+                return None
+
+            log.info("MinIO server already running. Cleaning up for fresh start.")
+
+            # If running is a list of PIDs, kill them
+            if isinstance(running, list):
+                for pid in running:
+                    try:
+                        os.kill(pid, signal.SIGTERM)
+                        log.info(f"Killed MinIO process with PID {pid}")
+                    except Exception as e:
+                        log.warning(f"Could not kill PID {pid}: {e}")
+                time.sleep(2)  # Give time for processes to terminate
+            else:
+                log.warning("MinIO server running but PID not found. Please check manually.")
+
+        # data_dir is required; log and bail out if it was not provided
+        if data_dir is None:
+            log.error("Data directory is required to start MinIO server.")
+            return None
+
+        # Create data directory if it doesn't exist
+        os.makedirs(data_dir, exist_ok=True)
+
+        # Check if minio is installed
+        minio_path = shutil.which("minio")
+        if minio_path is None:
+            log.error("MinIO server not found. Please install MinIO first.")
+            log.warning("You can download it from: https://min.io/download")
+            return None
+
+        # Set environment variables for the current process as well as the subprocess.
+        # This is important for MinIO to pick up the access and secret keys
+        # and for the subprocess to inherit them
+        env = os.environ.copy()
+        env["MINIO_ROOT_USER"] = os.environ["MINIO_ROOT_USER"] = access_key
+        env["MINIO_ROOT_PASSWORD"] = os.environ["MINIO_ROOT_PASSWORD"] = secret_key
+
+        # Start MinIO server
+        cmd = [
+            minio_path,
+            "server",
+            data_dir,
+            "--address",
+            address,
+            "--console-address",
+            console_address,
+        ]
+        log.info(
+            "Starting MinIO server with below configurations:"
+            f"\n  - Data Directory: {data_dir}"
+            f"\n  - Address: {address}"
+            f"\n  - Console Address: {console_address}"
+        )
+
+        # Start the process
+        process = subprocess.Popen(
+            cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+        )
+
+        # Register a function to stop the server at exit
+        def stop_server():
+            if process.poll() is None:  # If process is still running
+                log.info("Stopping MinIO server...")
+                process.send_signal(signal.SIGTERM)
+                try:
+                    process.wait(timeout=5)
+                except subprocess.TimeoutExpired:
+                    process.kill()
+                    process.wait()
+
+        atexit.register(stop_server)
+
+        # Wait for server to start
+        time.sleep(2)
+
+        # Check if server started successfully
+        if process.poll() is not None:
+            # Process exited already
+            out, err = process.communicate()
+            log.error("Failed to start MinIO server:")
+            log.info(f"STDOUT: {out}")
+            log.error(f"STDERR: {err}")
+            return None
+
+        log.info("MinIO server started successfully.")
+        return process
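A hypothetical end-to-end usage sketch of this class. It assumes the minio binary is on PATH and ./minio_data is writable; the bucket and folder names are illustrative, and create_bucket, upload_directory, and list_buckets are defined below.

    from tests.end_to_end.models.s3_bucket import S3Bucket

    s3 = S3Bucket()  # defaults to http://localhost:9000 with minioadmin credentials
    proc = s3.start_minio_server(data_dir="./minio_data")
    s3.create_bucket("bucket-1")                 # defined below
    s3.upload_directory("./data/1", "bucket-1")  # defined below
    print(s3.list_buckets())                     # e.g. ['bucket-1']
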
+    def create_bucket(self, bucket_name):
+        """
+        Create a new bucket if it doesn't exist.
+
+        Args:
+            bucket_name: Name of the bucket to create
+
+        Returns:
+            bool: True if bucket was created or already exists, False on error
+        """
+        try:
+            # Check if bucket already exists
+            self.client.head_bucket(Bucket=bucket_name)
+            log.info(f"Bucket {bucket_name} already exists.")
+            return True
+        except ClientError as e:
+            # If bucket doesn't exist, create it
+            if e.response["Error"]["Code"] == "404":
+                try:
+                    self.client.create_bucket(Bucket=bucket_name)
+                    log.info(f"Bucket {bucket_name} created successfully.")
+                    return True
+                except ClientError as create_error:
+                    log.error(f"Error creating bucket: {create_error}")
+                    return False
+            else:
+                log.error(f"Error checking bucket: {e}")
+                return False
+
+    def delete_bucket(self, bucket_name, force=False):
+        """
+        Delete a bucket.
+
+        Args:
+            bucket_name: Name of the bucket to delete
+            force: If True, delete all objects in the bucket before deletion
+
+        Returns:
+            bool: True if bucket was deleted, False on error
+        """
+        try:
+            if force:
+                # Delete all objects in the bucket first
+                self.delete_all_objects(bucket_name)
+
+            # Delete the bucket
+            self.client.delete_bucket(Bucket=bucket_name)
+            log.info(f"Bucket {bucket_name} deleted successfully.")
+            return True
+        except ClientError as e:
+            log.error(f"Error deleting bucket {bucket_name}: {e}")
+            return False
+
+    def list_buckets(self):
+        """
+        List all buckets.
+ + Returns: + list: List of bucket names + """ + try: + response = self.client.list_buckets() + buckets = [bucket["Name"] for bucket in response.get("Buckets", [])] + log.info(f"Found {len(buckets)} buckets: {', '.join(buckets)}") + return buckets + except ClientError as e: + log.error(f"Error listing buckets: {e}") + return [] + + def upload_file(self, file_path, bucket_name, object_name=None): + """ + Upload a file to a bucket. + + Args: + file_path: Path to the file to upload + bucket_name: Name of the bucket + object_name: S3 object name (if None, uses file_path basename) + + Returns: + bool: True if file was uploaded, False on error + """ + # If object_name was not specified, use file_path basename + if object_name is None: + object_name = Path(file_path).name + + try: + self.client.upload_file(file_path, bucket_name, object_name) + log.debug(f"File {file_path} uploaded to {bucket_name}/{object_name}") + return True + except ClientError as e: + log.error(f"Error uploading file {file_path}: {e}") + return False + + def upload_directory(self, dir_path, bucket_name, prefix=""): + """ + Upload all files from a directory to a bucket. + + Args: + dir_path: Path to the directory to upload + bucket_name: Name of the bucket + prefix: Prefix to add to object names + + Returns: + int: Number of files uploaded + """ + dir_path = Path(dir_path) + count = 0 + + if not dir_path.is_dir(): + log.error(f"Error: {dir_path} is not a directory") + return count + + for root, _, files in os.walk(dir_path): + for file in files: + file_path = Path(root) / file + # Calculate relative path from dir_path + rel_path = file_path.relative_to(dir_path) + # Create object name with prefix + if prefix: + object_name = f"{prefix}/{rel_path}" + else: + object_name = str(rel_path) + + if self.upload_file(str(file_path), bucket_name, object_name): + count += 1 + + log.info(f"Uploaded {count} files to {bucket_name} from {dir_path}") + return count + + def download_file(self, bucket_name, object_name, file_path=None): + """ + Download a file from a bucket. + + Args: + bucket_name: Name of the bucket + object_name: S3 object name + file_path: Local path to save the file (if None, uses object_name basename) + + Returns: + bool: True if file was downloaded, False on error + """ + # If file_path was not specified, use object_name basename + if file_path is None: + file_path = Path(object_name).name + + try: + # Create directory if it doesn't exist + os.makedirs(Path(file_path).parent, exist_ok=True) + + self.client.download_file(bucket_name, object_name, file_path) + log.info(f"Downloaded {bucket_name}/{object_name} to {file_path}") + return True + except ClientError as e: + log.error(f"Error downloading {bucket_name}/{object_name}: {e}") + return False + + def download_directory(self, bucket_name, prefix, local_dir=None): + """ + Download all files with a prefix from a bucket. + + Args: + bucket_name: Name of the bucket + prefix: Prefix of objects to download + local_dir: Local directory to save files (if None, uses current dir) + + Returns: + int: Number of files downloaded + """ + if local_dir is None: + local_dir = "." 
+
+        local_dir = Path(local_dir)
+        os.makedirs(local_dir, exist_ok=True)
+
+        count = 0
+        try:
+            # List all objects with the prefix
+            paginator = self.client.get_paginator("list_objects_v2")
+            pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+
+            for page in pages:
+                if "Contents" not in page:
+                    continue
+
+                for obj in page["Contents"]:
+                    object_name = obj["Key"]
+
+                    # Calculate relative path from prefix
+                    if prefix and object_name.startswith(prefix):
+                        rel_path = object_name[len(prefix):]
+                        if rel_path.startswith("/"):
+                            rel_path = rel_path[1:]
+                    else:
+                        rel_path = object_name
+
+                    # Create local file path
+                    file_path = local_dir / rel_path
+
+                    if self.download_file(bucket_name, object_name, str(file_path)):
+                        count += 1
+
+            log.info(
+                f"Downloaded {count} files from {bucket_name}/{prefix} to {local_dir}"
+            )
+            return count
+        except ClientError as e:
+            log.error(f"Error downloading from {bucket_name}/{prefix}: {e}")
+            return count
+
+    def list_objects(self, bucket_name, prefix="", recursive=True, max_items=None):
+        """
+        List objects in a bucket with an optional prefix.
+
+        Args:
+            bucket_name: Name of the bucket
+            prefix: Prefix filter for objects
+            recursive: If False, emulates directory listing with delimiters
+            max_items: Maximum number of items to return
+
+        Returns:
+            list: List of object keys
+        """
+        try:
+            paginator = self.client.get_paginator("list_objects_v2")
+
+            # Set up pagination parameters
+            pagination_config = {}
+            if max_items:
+                pagination_config["MaxItems"] = max_items
+
+            # Set up operation parameters
+            operation_params = {"Bucket": bucket_name, "Prefix": prefix}
+
+            # If not recursive, use delimiter to emulate directory listing
+            if not recursive:
+                operation_params["Delimiter"] = "/"
+
+            # Get pages of objects
+            pages = paginator.paginate(
+                **operation_params, PaginationConfig=pagination_config
+            )
+
+            objects = []
+
+            for page in pages:
+                # Add objects
+                if "Contents" in page:
+                    for obj in page["Contents"]:
+                        objects.append(obj["Key"])
+
+                # Add common prefixes (folders) if not recursive. Use a distinct
+                # loop variable so the `prefix` argument is not clobbered before
+                # the log line below.
+                if not recursive and "CommonPrefixes" in page:
+                    for common_prefix in page["CommonPrefixes"]:
+                        objects.append(common_prefix["Prefix"])
+
+            log.info(f"Found {len(objects)} objects in {bucket_name}/{prefix}")
+            for obj in objects:
+                log.info(f"- {obj}")
+
+            return objects
+        except ClientError as e:
+            log.error(f"Error listing objects in {bucket_name}/{prefix}: {e}")
+            return []
+
+    def delete_object(self, bucket_name, object_name):
+        """
+        Delete an object from a bucket.
+
+        Args:
+            bucket_name: Name of the bucket
+            object_name: S3 object name to delete
+
+        Returns:
+            bool: True if object was deleted, False on error
+        """
+        try:
+            self.client.delete_object(Bucket=bucket_name, Key=object_name)
+            log.info(f"Deleted {bucket_name}/{object_name}")
+            return True
+        except ClientError as e:
+            log.error(f"Error deleting {bucket_name}/{object_name}: {e}")
+            return False
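Continuing the hypothetical sketch above (bucket and prefix names are illustrative): with recursive=False the "/" delimiter makes list_objects return top-level pseudo-folders instead of every key, which pairs naturally with the deletion helpers that follow.

    top_level = s3.list_objects("bucket-1", recursive=False)  # e.g. ['01_TUMOR/', '02_STROMA/']
    keys = s3.list_objects("bucket-1", prefix="01_TUMOR/", max_items=50)
    for key in keys:
        s3.delete_object("bucket-1", key)  # delete_objects (below) batches these instead
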
+    def delete_objects(self, bucket_name, object_names):
+        """
+        Delete multiple objects from a bucket.
+
+        Args:
+            bucket_name: Name of the bucket
+            object_names: List of object names to delete
+
+        Returns:
+            int: Number of objects deleted
+        """
+        if not object_names:
+            return 0
+
+        try:
+            # Create delete request
+            objects = [{"Key": obj} for obj in object_names]
+            response = self.client.delete_objects(
+                Bucket=bucket_name, Delete={"Objects": objects}
+            )
+
+            deleted = len(response.get("Deleted", []))
+            errors = len(response.get("Errors", []))
+
+            log.info(f"Deleted {deleted} objects from {bucket_name}")
+            if errors > 0:
+                log.error(f"Failed to delete {errors} objects")
+
+            return deleted
+        except ClientError as e:
+            log.error(f"Error deleting objects from {bucket_name}: {e}")
+            return 0
+
+    def delete_prefix(self, bucket_name, prefix):
+        """
+        Delete all objects with a specific prefix (like a folder).
+
+        Args:
+            bucket_name: Name of the bucket
+            prefix: Prefix of objects to delete
+
+        Returns:
+            int: Number of objects deleted
+        """
+        try:
+            # List all objects with the prefix
+            objects = self.list_objects(bucket_name, prefix)
+
+            # Delete the objects in batches
+            count = 0
+            batch_size = 1000  # S3 limits delete_objects to 1000 at a time
+
+            for i in range(0, len(objects), batch_size):
+                batch = objects[i:i + batch_size]
+                count += self.delete_objects(bucket_name, batch)
+
+            log.info(f"Deleted {count} objects from {bucket_name}/{prefix}")
+            return count
+        except ClientError as e:
+            log.error(f"Error deleting prefix {bucket_name}/{prefix}: {e}")
+            return 0
+
+    def delete_all_objects(self, bucket_name):
+        """
+        Delete all objects in a bucket.
+
+        Args:
+            bucket_name: Name of the bucket
+
+        Returns:
+            int: Number of objects deleted
+        """
+        return self.delete_prefix(bucket_name, "")
+
+    def split_directory_to_buckets(
+        self, source_path, bucket_name, folder_names, split_folders=None
+    ):
+        """
+        Distribute folders from a directory across separate prefixes in a bucket.
+ + Args: + source_path: Path to the directory containing folders to split + bucket_name: Name of the bucket to upload to + folder_names: List of folder names to upload + split_folders: Dictionary mapping folders to destination prefixes, + if None, splits into equal groups + + Returns: + dict: Mapping of destination prefixes to lists of folders uploaded + """ + source_path = Path(source_path) + if not source_path.is_dir(): + log.error(f"Error: {source_path} is not a directory") + return {} + + # Ensure bucket exists + self.create_bucket(bucket_name) + + # Get folders in source directory that match requested folder names + folders = [] + for folder_name in folder_names: + folder_path = source_path / folder_name + if folder_path.is_dir(): + folders.append(folder_name) + else: + log.warning(f"Warning: {folder_path} is not a directory, skipping") + + # If split_folders is None, create equal groups + if split_folders is None: + half = len(folders) // 2 + split_folders = {"1": folders[:half], "2": folders[half:]} + + result = {} + + # Upload each group of folders to the specified prefix + for prefix, group_folders in split_folders.items(): + result[prefix] = [] + + for folder in group_folders: + if folder in folders: + folder_path = source_path / folder + # Upload the folder with the prefix + upload_prefix = f"{prefix}/{folder}" + count = self.upload_directory( + folder_path, bucket_name, upload_prefix + ) + if count > 0: + result[prefix].append(folder) + log.info(f"Uploaded {folder} to {bucket_name}/{upload_prefix}") + + return result + + def copy_object(self, source_bucket, source_key, dest_bucket, dest_key=None): + """ + Copy an object within or between buckets. + + Args: + source_bucket: Source bucket name + source_key: Source object key + dest_bucket: Destination bucket name + dest_key: Destination object key (if None, uses source_key) + + Returns: + bool: True if object was copied, False on error + """ + if dest_key is None: + dest_key = source_key + + try: + copy_source = {"Bucket": source_bucket, "Key": source_key} + + self.client.copy_object( + CopySource=copy_source, Bucket=dest_bucket, Key=dest_key + ) + + log.info(f"Copied {source_bucket}/{source_key} to {dest_bucket}/{dest_key}") + return True + except ClientError as e: + log.error(f"Error copying {source_bucket}/{source_key}: {e}") + return False + + def search_objects(self, bucket_name, pattern, prefix=""): + """ + Search for objects in a bucket using a glob pattern. + + Args: + bucket_name: Name of the bucket + pattern: Glob pattern to match object keys against + prefix: Optional prefix to limit search scope + + Returns: + list: List of matching object keys + """ + objects = self.list_objects(bucket_name, prefix) + matches = [obj for obj in objects if fnmatch.fnmatch(obj, pattern)] + + log.info( + f"Found {len(matches)} objects matching '{pattern}' in {bucket_name}/{prefix}" + ) + for obj in matches: + log.info(f"- {obj}") + + return matches diff --git a/tests/end_to_end/pytest.ini b/tests/end_to_end/pytest.ini index 2e8d4c9d69..372b5b8b10 100644 --- a/tests/end_to_end/pytest.ini +++ b/tests/end_to_end/pytest.ini @@ -9,7 +9,10 @@ markers = task_runner_basic: mark a test as a task runner basic test. task_runner_dockerized_ws: mark a test as a task runner dockerized workspace test. task_runner_basic_gandlf: mark a test as a task runner basic for GanDLF test. + task_runner_connectivity: mark a test as a connectivity test. + task_runner_with_s3: mark a test as a task runner with S3 test. 
    federated_runtime_301_watermarking: mark a test as a federated runtime 301 watermarking test.
    straggler_tests: mark a test as a straggler test.
+    task_runner_fed_analytics: mark a test as a task runner analytics test.

 asyncio_mode=auto
 asyncio_default_fixture_loop_scope="function"
diff --git a/tests/end_to_end/test_suites/task_runner_tests.py b/tests/end_to_end/test_suites/task_runner_tests.py
index 83b51f7ca5..52fe7c00f6 100644
--- a/tests/end_to_end/test_suites/task_runner_tests.py
+++ b/tests/end_to_end/test_suites/task_runner_tests.py
@@ -57,7 +57,7 @@ def test_federation_via_dockerized_workspace(request, fx_federation_tr_dws):
     log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}")


-@pytest.mark.task_runner_basic_connectivity
+@pytest.mark.task_runner_connectivity
 def test_federation_connectivity(request, fx_federation_tr):
     """
     Verify that the collaborator can ping the aggregator. If the ping is successful, the collaborator can start training.
diff --git a/tests/end_to_end/test_suites/tr_fed_analytics_tests.py b/tests/end_to_end/test_suites/tr_fed_analytics_tests.py
new file mode 100644
index 0000000000..2300a216ea
--- /dev/null
+++ b/tests/end_to_end/test_suites/tr_fed_analytics_tests.py
@@ -0,0 +1,70 @@
+# Copyright 2020-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import logging
+import os
+
+
+from tests.end_to_end.utils.tr_common_fixtures import (
+    fx_federation_tr,
+)
+from tests.end_to_end.utils import federation_helper as fed_helper
+import json
+import tests.end_to_end.utils.constants as constants
+
+log = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="function")
+def set_num_rounds(request):
+    """
+    Fixture to set the number of rounds to 1 and skip models that are not
+    federated analytics models.
+    Args:
+        request (Fixture): Pytest fixture
+    """
+    # Set the number of rounds to 1
+    log.info("Setting number of rounds to 1 for analytics test")
+    request.config.num_rounds = 1
+    if "federated_analytics" not in request.config.model_name:
+        pytest.skip(
+            f"Model name {request.config.model_name} is not supported for this test. "
+            "Please use a different model name."
+        )
+
+
+@pytest.mark.task_runner_fed_analytics
+def test_federation_analytics(request, set_num_rounds, fx_federation_tr):
+    """
+    Test federated analytics via the native task runner.
+    Args:
+        request (Fixture): Pytest fixture
+        fx_federation_tr (Fixture): Pytest fixture for native task runner
+    """
+    # Start the federation
+    assert fed_helper.run_federation(fx_federation_tr)
+
+    # Verify the completion of the federation run
+    assert fed_helper.verify_federation_run_completion(
+        fx_federation_tr,
+        test_env=request.config.test_env,
+        num_rounds=request.config.num_rounds,
+    ), "Federation completion failed"
+
+    # Verify that results get saved in save/result.json
+    result_path = os.path.join(
+        fx_federation_tr.aggregator.workspace_path,
+        "save",
+        "result.json"
+    )
+    assert os.path.exists(result_path), f"Results file {result_path} does not exist"
+
+    with open(result_path, "r") as f:
+        results = f.read()
+    try:
+        results_json = json.loads(results)
+    except json.JSONDecodeError as e:
+        log.warning("Results file is not valid JSON. 
Raw content:\n%s", results) + raise e + + assert results, f"Results file {result_path} is empty" \ No newline at end of file diff --git a/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py b/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py new file mode 100644 index 0000000000..62aeb40beb --- /dev/null +++ b/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py @@ -0,0 +1,41 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import logging + +from tests.end_to_end.utils.tr_common_fixtures import fx_federation_tr +from tests.end_to_end.utils import federation_helper as fed_helper + +log = logging.getLogger(__name__) + + +@pytest.mark.task_runner_with_s3 +def test_federation_with_s3_bucket(request, fx_federation_tr): + """ + Test federation with S3 bucket. Model name - torch/histology_s3 + Steps: + 1. Start the minio server, create buckets for every collaborator. + 2. Download data using torch/histology dataloader and upload data to the buckets. + 3. Create a datasources.json file for each collaborator which will contain the S3 bucket and/or local datasources. + 4. Calculate hash for each collaborator's data (it generates hash.txt file under the data directory). + 5. Start the federation (internally the hash is verified as well). + 6. Verify the completion of the federation run. + 7. Verify the best aggregated score. + Args: + request (Fixture): Pytest fixture + fx_federation_tr (Fixture): Pytest fixture for native task runner + """ + # Start the federation + assert fed_helper.run_federation(fx_federation_tr) + + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion( + fx_federation_tr, + test_env=request.config.test_env, + num_rounds=request.config.num_rounds, + time_for_each_round=300, + ), "Federation completion failed" + + best_agg_score = fed_helper.get_best_agg_score(fx_federation_tr.aggregator.tensor_db_file) + log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}") diff --git a/tests/end_to_end/utils/constants.py b/tests/end_to_end/utils/constants.py index c3e63de45f..69bb8728cf 100644 --- a/tests/end_to_end/utils/constants.py +++ b/tests/end_to_end/utils/constants.py @@ -14,6 +14,7 @@ class ModelName(Enum): KERAS_MNIST = "keras/mnist" KERAS_TORCH_MNIST = "keras/torch/mnist" TORCH_HISTOLOGY = "torch/histology" + TORCH_HISTOLOGY_S3 = "torch/histology_s3" TORCH_MNIST = "torch/mnist" TORCH_MNIST_EDEN_COMPRESSION = "torch/mnist_eden_compression" TORCH_MNIST_STRAGGLER_CHECK = "torch/mnist_straggler_check" @@ -21,6 +22,9 @@ class ModelName(Enum): GANDLF_SEG_TEST = "gandlf_seg_test" FLOWER_APP_PYTORCH = "flower-app-pytorch" NO_OP = "no-op" + KERAS_TENSORFLOW_MNIST = "keras/tensorflow/mnist" + FEDERATED_ANALYTICS_HISTOGRAM = "federated_analytics/histogram" + FEDERATED_ANALYTICS_SMOKERS_HEALTH = "federated_analytics/smokers_health" NUM_COLLABORATORS = 2 NUM_ROUNDS = 5 @@ -60,3 +64,13 @@ class ModelName(Enum): EXCEPTION = "Exception" AGG_METRIC_MODEL_ACCURACY_KEY = "aggregator/aggregated_model_validation/accuracy" COL_TLS_END_MSG = "TLS connection established." 
+ +# For S3 and MinIO +MINIO_ROOT_USER = "minioadmin" +MINIO_ROOT_PASSWORD = "minioadmin" +MINIO_HOST = "localhost" +MINIO_PORT = 9000 +MINIO_CONSOLE_PORT = 9001 +MINIO_URL = f"http://{MINIO_HOST}:{MINIO_PORT}" +MINIO_CONSOLE_URL = f"http://{MINIO_HOST}:{MINIO_CONSOLE_PORT}" +MINIO_DATA_FOLDER = "minio_data" diff --git a/tests/end_to_end/utils/exceptions.py b/tests/end_to_end/utils/exceptions.py index e7c353eaa3..99e37e24ac 100644 --- a/tests/end_to_end/utils/exceptions.py +++ b/tests/end_to_end/utils/exceptions.py @@ -129,3 +129,28 @@ class FlowerAppException(Exception): class ProcessKillException(Exception): """Exception for process kill""" pass + + +class HashCalculationException(Exception): + """Exception for hash calculation of collaborator's data path""" + pass + + +class MinioServerStartException(Exception): + """Exception for minio server start""" + pass + + +class S3BucketCreationException(Exception): + """Exception for S3 bucket creation""" + pass + + +class DataDownloadException(Exception): + """Exception for data download""" + pass + + +class DataUploadToS3Exception(Exception): + """Exception for data upload to S3""" + pass diff --git a/tests/end_to_end/utils/federation_helper.py b/tests/end_to_end/utils/federation_helper.py index 99d1a6c6b5..63c7261939 100644 --- a/tests/end_to_end/utils/federation_helper.py +++ b/tests/end_to_end/utils/federation_helper.py @@ -18,6 +18,7 @@ import tests.end_to_end.utils.docker_helper as dh import tests.end_to_end.utils.exceptions as ex import tests.end_to_end.utils.interruption_helper as intr_helper +import tests.end_to_end.utils.s3_helper as s3_helper import tests.end_to_end.utils.ssh_helper as ssh from tests.end_to_end.models import collaborator as col_model from tests.end_to_end.utils.generate_report import convert_to_json @@ -271,13 +272,14 @@ def run_federation_for_dws(fed_obj, use_tls): return True -def verify_federation_run_completion(fed_obj, test_env, num_rounds): +def verify_federation_run_completion(fed_obj, test_env, num_rounds, time_for_each_round=100): """ Verify the completion of the process for all the participants Args: fed_obj (object): Federation fixture object test_env (str): Test environment num_rounds (int): Number of rounds + time_for_each_round (int): Time for each round (in seconds) Returns: list: List of response (True or False) for all the participants """ @@ -291,6 +293,7 @@ def verify_federation_run_completion(fed_obj, test_env, num_rounds): participant, num_rounds, num_collaborators=len(fed_obj.collaborators), + time_for_each_round=time_for_each_round, ) for participant in fed_obj.collaborators + [fed_obj.aggregator] ] @@ -569,7 +572,7 @@ def verify_cmd_output( raise Exception(f"{error_msg}: {error}") -def setup_collaborator(index, workspace_path, local_bind_path): +def setup_collaborator(index, workspace_path, local_bind_path, data_path=None, calc_hash=False, colab_bucket_mapping=None): """ Setup the collaborator Includes - creation of collaborator objects, starting docker container, importing workspace, creating collaborator @@ -577,13 +580,16 @@ def setup_collaborator(index, workspace_path, local_bind_path): index (int): Index of the collaborator. Starts with 1. 
        workspace_path (str): Workspace path
        local_bind_path (str): Local bind path
+        data_path (str): Data path
+        calc_hash (bool): Flag to indicate if hash calculation is required
+        colab_bucket_mapping (dict): Mapping of collaborator and its datasources
     """
     local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path)

     try:
         collaborator = col_model.Collaborator(
             collaborator_name=f"collaborator{index}",
-            data_directory_path=index,
+            data_directory_path=index if data_path is None else data_path,
             workspace_path=f"{workspace_path}/collaborator{index}/workspace",
         )
         create_persistent_store(collaborator.name, local_bind_path)
@@ -611,6 +617,28 @@ def setup_collaborator(index, workspace_path, local_bind_path):
     except Exception as e:
         raise ex.CollaboratorCreationException(f"Failed to create collaborator: {e}")

+    # Calculate the hash of the collaborator datasource (specific to the torch/histology_s3 model).
+    if calc_hash:
+        json_data = s3_helper.create_collaborator_datasource_json(
+            colab_bucket_mapping=colab_bucket_mapping,
+        )
+        # Modify the data/collaborator{index}/datasources.json file
+        # to include the data path for the collaborator
+        data_source_file = os.path.join(
+            local_col_ws_path, "data", "datasources.json"
+        )
+        with open(data_source_file, "w") as file:
+            json.dump(json_data, file, indent=4)
+        log.debug(f"Modified data source file for {collaborator.name}: {data_source_file}")
+
+        try:
+            # Calculate hash for the collaborator
+            collaborator.calculate_hash()
+        except Exception as e:
+            raise ex.HashCalculationException(
+                f"Failed to calculate hash for {collaborator.name}: {e}"
+            )
+
     return collaborator
diff --git a/tests/end_to_end/utils/s3_helper.py b/tests/end_to_end/utils/s3_helper.py
new file mode 100644
index 0000000000..2dd7f74062
--- /dev/null
+++ b/tests/end_to_end/utils/s3_helper.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import shutil
+import logging
+from pathlib import Path
+
+import tests.end_to_end.utils.constants as constants
+import tests.end_to_end.utils.exceptions as ex
+
+log = logging.getLogger(__name__)
+
+
+def create_collaborator_datasource_json(colab_bucket_mapping, endpoint=constants.MINIO_URL):
+    """
+    Create the datasources.json content for a collaborator.
+
+    Args:
+        colab_bucket_mapping (dict): Mapping of the given collaborator to its datasources
+        endpoint (str): S3 endpoint URL
+
+    Returns:
+        dict: The datasource configuration, ready to be serialized as datasources.json
+    """
+    collaborator_name = colab_bucket_mapping["collaborator"]
+    buckets = colab_bucket_mapping["buckets"]
+    local_data_path = colab_bucket_mapping["local_data_path"]
+    index = int(''.join(filter(str.isdigit, collaborator_name)))
+    data = {}
+
+    for i, bucket in enumerate(buckets, 1):
+        ds_key = f"s3_ds{i}"
+        data[ds_key] = {
+            "type": "s3",
+            "params": {
+                "access_key_env_name": "MINIO_ROOT_USER",
+                "endpoint": endpoint,
+                "secret_key_env_name": "MINIO_ROOT_PASSWORD",
+                "secret_name": f"vault_secret_name{i}",
+                "uri": f"s3://{bucket}/"
+            }
+        }
+    # Odd-indexed collaborators additionally get a local datasource
+    if index % 2 == 1:
+        data[f"local_ds{index}"] = {
+            "type": "local",
+            "params": {
+                "path": str(Path(local_data_path).relative_to(Path.cwd()))
+            }
+        }
+
+    return data
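For illustration, collaborator1 (odd index, a single bucket) would end up with a datasources.json shaped like this; the values are hypothetical, derived from the function above:

    {
        "s3_ds1": {
            "type": "s3",
            "params": {
                "access_key_env_name": "MINIO_ROOT_USER",
                "endpoint": "http://localhost:9000",
                "secret_key_env_name": "MINIO_ROOT_PASSWORD",
                "secret_name": "vault_secret_name1",
                "uri": "s3://bucket-1/"
            }
        },
        "local_ds1": {
            "type": "local",
            "params": {"path": "data/1"}
        }
    }
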
+
+
+def upload_data_to_s3(s3_obj, colab_bucket_mapping_list):
+    """
+    Upload data to S3 buckets based on the provided mapping.
+
+    Args:
+        s3_obj (S3Bucket): S3Bucket object for S3 operations
+        colab_bucket_mapping_list (list): List of dictionaries mapping each collaborator to its buckets
+    Returns:
+        bool: True if the upload succeeded; raises DataUploadToS3Exception otherwise
+    """
+    for colab in colab_bucket_mapping_list:
+        folder_path = Path(colab["local_data_path"])
+        buckets = colab["buckets"]
+        if len(buckets) == 2:
+            # Split the folder contents equally between the two buckets
+            all_items = sorted([item for item in folder_path.iterdir() if item.is_dir() or item.is_file()])
+            mid = len(all_items) // 2
+            split_items = [all_items[:mid], all_items[mid:]]
+            for i, bucket_name in enumerate(buckets):
+                temp_dir = folder_path / f"tmp_upload_{i+1}"
+                temp_dir.mkdir(exist_ok=True)
+                for item in split_items[i]:
+                    dest = temp_dir / item.name
+                    if item.is_dir():
+                        shutil.copytree(item, dest)
+                    else:
+                        shutil.copy2(item, dest)
+                try:
+                    s3_obj.upload_directory(dir_path=temp_dir, bucket_name=bucket_name)
+                    log.info(f"Uploaded data to bucket {bucket_name} from {temp_dir}")
+                except Exception as e:
+                    raise ex.DataUploadToS3Exception(
+                        f"Failed to upload data to bucket {bucket_name}. Error: {e}"
+                    )
+                shutil.rmtree(temp_dir)
+        else:
+            # Only one bucket, upload the whole folder
+            bucket_name = buckets[0]
+            try:
+                s3_obj.upload_directory(dir_path=folder_path, bucket_name=bucket_name)
+                log.info(f"Uploaded data to bucket {bucket_name} from {folder_path}")
+            except Exception as e:
+                raise ex.DataUploadToS3Exception(
+                    f"Failed to upload data to bucket {bucket_name}. Error: {e}"
+                )
+    return True
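A hypothetical invocation, mirroring the mapping format that prepare_data_for_s3 (below) produces; the paths, bucket names, and the `s3` object are illustrative:

    mapping = [
        {"collaborator": "collaborator1", "local_data_path": "/tmp/data/1", "buckets": ["bucket-1"]},
        {"collaborator": "collaborator2", "local_data_path": "/tmp/data/2", "buckets": ["bucket-2-01", "bucket-2-02"]},
    ]
    # collaborator2's folders are split evenly across its two buckets
    upload_data_to_s3(s3, mapping)
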
""" + if request.config.model_name.lower() == constants.ModelName.TORCH_HISTOLOGY_S3.value: + colab_bucket_mapping_list = prepare_data_for_s3(request) + # get details of model owner, collaborators, and aggregator from common # workspace creation function workspace_path, local_bind_path, agg_domain_name, model_owner, plan_path, agg_workspace_path, initial_model_path = common_workspace_creation(request, eval_scope) @@ -135,15 +140,34 @@ def create_tr_workspace(request, eval_scope=False): collaborators = [] executor = concurrent.futures.ThreadPoolExecutor() - futures = [ - executor.submit( - fh.setup_collaborator, - index, - workspace_path=workspace_path, - local_bind_path=local_bind_path, - ) - for index in range(1, request.config.num_collaborators+1) - ] + # In case of torch/histology_s3, we need to pass the data path, flag to calculate hash + # and bucket mapping to the setup_collaborator function + if request.config.model_name.lower() == constants.ModelName.TORCH_HISTOLOGY_S3.value: + futures = [ + executor.submit( + fh.setup_collaborator, + index, + workspace_path=workspace_path, + local_bind_path=local_bind_path, + data_path="data", + calc_hash=True, + colab_bucket_mapping=next( + (item for item in colab_bucket_mapping_list if item["collaborator"] == f"collaborator{index}"), + None + ), + ) + for index in range(1, request.config.num_collaborators+1) + ] + else: + futures = [ + executor.submit( + fh.setup_collaborator, + index, + workspace_path=workspace_path, + local_bind_path=local_bind_path, + ) + for index in range(1, request.config.num_collaborators+1) + ] collaborators = [f.result() for f in futures] # Data setup requires total no of collaborators, thus keeping the function call @@ -373,3 +397,146 @@ def create_tr_dws_workspace(request, eval_scope=False): local_bind_path=local_bind_path, model_name=request.config.model_name, ) + + +def prepare_data_for_s3(request): + """ + Prepare data for S3. Includes starting minio server, creating bucket, and uploading data. + Args: + request (object): Pytest request object. + Returns: + dict: A dictionary containing the bucket mapping for each collaborator. + Example - + [ + {'collaborator': 'collaborator1', 'local_data_path': '/home/azureuser/openfl/data/1', 'buckets': ['bucket-1']}, + {'collaborator': 'collaborator2', 'local_data_path': '/home/azureuser/openfl/data/2', 'buckets': ['bucket-2-01', 'bucket-2-02']} + ] + """ + s3_obj = s3_model.S3Bucket() + + num_collaborators = request.config.num_collaborators + + # Import the dataloader module for torch/histology to download the data + # As the folder name contains hyphen, we need to use importlib to import the module + dataloader_module = importlib.import_module("openfl-workspace.torch.histology.src.dataloader") + + # Download the data for torch/histology in current folder as internally it uses the current folder as data path + try: + log.info(f"Downloading data for {constants.ModelName.TORCH_HISTOLOGY_S3.value}") + dataloader_module.HistologyDataset() + log.info("Download completed") + except Exception as e: + raise ex.DataDownloadException( + f"Failed to download data for {constants.ModelName.TORCH_HISTOLOGY_S3.value}. 
+        )
+
+    # Distribute the downloaded data/folders among the collaborators
+    hist_data_path = Path.cwd().absolute() / 'data'  # Cannot be changed: the dataloader uses this path unconditionally
+    try:
+        distribute_data_to_collaborators(num_collaborators, hist_data_path)
+    except Exception as e:
+        raise ex.DataSetupException(
+            f"Failed to distribute data to collaborators. Error: {e}"
+        )
+
+    # Start the MinIO server, create S3 buckets and upload the data to S3
+    try:
+        s3_obj.start_minio_server(
+            data_dir=os.path.join(Path().home(), request.config.results_dir, constants.MINIO_DATA_FOLDER)
+        )
+        log.info("Started MinIO server")
+    except Exception as e:
+        raise ex.MinioServerStartException(
+            f"Failed to start MinIO server. Error: {e}"
+        )
+
+    # Create the buckets for each collaborator. Odd-indexed collaborators get a
+    # single bucket (bucket-<index>); even-indexed collaborators get two
+    # (bucket-<index>-01 and bucket-<index>-02) so their data spans multiple sources.
+    colab_bucket_mapping_list = []
+    bucket_name = None
+    for index in range(1, num_collaborators + 1):
+        try:
+            folder_path = hist_data_path / str(index)
+            if index % 2 == 0:
+                bucket_list = []
+                for suffix in ["01", "02"]:
+                    bucket_name = f"bucket-{index}-{suffix}"
+                    s3_obj.create_bucket(bucket_name=bucket_name)
+                    log.info(f"Created bucket {bucket_name}")
+                    bucket_list.append(bucket_name)
+                colab_bucket_mapping_list.append({
+                    "collaborator": f"collaborator{index}",
+                    "local_data_path": str(folder_path),
+                    "buckets": bucket_list
+                })
+            else:
+                bucket_name = f"bucket-{index}"
+                s3_obj.create_bucket(bucket_name=bucket_name)
+                log.info(f"Created bucket {bucket_name}")
+                colab_bucket_mapping_list.append({
+                    "collaborator": f"collaborator{index}",
+                    "local_data_path": str(folder_path),
+                    "buckets": [bucket_name]
+                })
+        except Exception as e:
+            raise ex.S3BucketCreationException(
+                f"Failed to create bucket {bucket_name} for collaborator{index}. Error: {e}"
+            )
+
+    log.info(f"Bucket mapping: {colab_bucket_mapping_list}")
+
+    # List the buckets to verify
+    s3_obj.list_buckets()
+
+    # Copy the data to the S3 buckets, distributing each collaborator's data across its bucket(s)
+    s3_helper.upload_data_to_s3(s3_obj, colab_bucket_mapping_list)
+
+    return colab_bucket_mapping_list
+
+
+def distribute_data_to_collaborators(num_collaborators, data_path):
+    """
+    Distribute the data among the collaborators uniformly.
+    Example: Assuming num_collaborators is 3.
+    If data_path has a folder Kather_texture_2016_image_tiles_5000 (torch/histology) which in turn has 8 subfolders,
+    the data will be distributed as:
+        collaborator1: folder 1 / first 3 subfolders
+        collaborator2: folder 2 / next 3 subfolders
+        collaborator3: folder 3 / last 2 subfolders
+    If data_path itself has multiple folders, say 8, the data will be distributed as:
+        collaborator1: folder 1 / first 3 folders
+        collaborator2: folder 2 / next 3 folders
+        collaborator3: folder 3 / last 2 folders
+    Args:
+        num_collaborators (int): Number of collaborators.
+        data_path (str): Path to the data directory.
+    Raises:
+        Exception: If the data distribution fails.
+ """ + # If data_path has only one folder, go inside it and use its subfolders + all_entries = [f for f in data_path.iterdir() if f.is_dir()] + if len(all_entries) == 1: + # Use subfolders inside the single folder + all_folders = [f for f in all_entries[0].iterdir() if f.is_dir()] + else: + all_folders = all_entries + all_folders.sort() # For deterministic split + + num_folders = len(all_folders) + folders_per_collab = [num_folders // num_collaborators] * num_collaborators + + # Distribute the remainder (if any) to the first few collaborators + for i in range(num_folders % num_collaborators): + folders_per_collab[i] += 1 + + start = 0 + for index in range(1, num_collaborators + 1): + collaborator_data_path = data_path / str(index) + collaborator_data_path.mkdir(parents=True, exist_ok=True) + end = start + folders_per_collab[index - 1] + for folder in all_folders[start:end]: + # Move or copy the folder to the collaborator's directory + # Here we move; use shutil.copytree if you want to copy instead + folder.rename(collaborator_data_path / folder.name) + start = end