From 9133c830df36936256cd29415a893a01497fc105 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 21 May 2025 21:57:17 +0000 Subject: [PATCH 1/4] Add clang-format composite action --- ci_clangformat/.clang-format.default | 6 ++ ci_clangformat/README.md | 1 + ci_clangformat/action.yaml | 91 ++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 ci_clangformat/.clang-format.default create mode 100644 ci_clangformat/README.md create mode 100644 ci_clangformat/action.yaml diff --git a/ci_clangformat/.clang-format.default b/ci_clangformat/.clang-format.default new file mode 100644 index 0000000..720f2f8 --- /dev/null +++ b/ci_clangformat/.clang-format.default @@ -0,0 +1,6 @@ +BasedOnStyle: Google +Language: Cpp +PointerBindsToType: true +SortIncludes: Never +AlignTrailingComments: + Kind: Always diff --git a/ci_clangformat/README.md b/ci_clangformat/README.md new file mode 100644 index 0000000..eee712d --- /dev/null +++ b/ci_clangformat/README.md @@ -0,0 +1 @@ +# CI Clang-format diff --git a/ci_clangformat/action.yaml b/ci_clangformat/action.yaml new file mode 100644 index 0000000..1d53c52 --- /dev/null +++ b/ci_clangformat/action.yaml @@ -0,0 +1,91 @@ +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: "Check clang-format" +description: 'Action to run clang-format on changed files' +inputs: + clang_format_version: + description: 'The clang-format version to use (e.g., "17", "18").' + required: true + default: "17.0.6" + branch_name: + description: 'The name of the branch (e.g., "main") used for fetching comparisons.' + required: true + default: 'main' + +# outputs: +# changes_detected: +# description: 'True if clang-format detected formatting changes.' +# value: ${{ steps.check-format.outputs.changes_detected }} + +runs: + using: "composite" + steps: + - name: "Checking out repository" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: "Fetch HEAD of ${{ inputs.branch_name }} branch" + shell: bash + run: git fetch origin ${{ inputs.branch_name }} --depth=1 + # - name: "Determine and install clang-format" + # shell: bash + # id: install-clang-format + # run: | + # CLANG_VERSION="${{ inputs.clang_format_version }}" + # CLANG_FORMAT_NAME="" + + # if [ "$CLANG_VERSION" == "latest" ]; then + # CLANG_FORMAT_NAME="clang-format" + # else + # CLANG_FORMAT_NAME="clang-format-${CLANG_VERSION}" + # fi + + # if command -v "$CLANG_FORMAT_NAME" &> /dev/null; then + # echo "::notice::'$CLANG_FORMAT_NAME' is already installed." + # else + # echo "::warning::'$CLANG_FORMAT_NAME' not found. Installing..." + # sudo apt-get update + # if ! sudo apt-get install -y "$CLANG_FORMAT_NAME"; then + # echo "::error::Failed to install '$CLANG_FORMAT_NAME'. Please check the version or try 'latest'." 
+ # exit 1 + # fi + # fi + # echo "clang_format_exe=$CLANG_FORMAT_NAME" >> "$GITHUB_OUTPUT" + # - name: "Prepare .clang-format file" + # shell: bash + # id: prepare-config + # run: | + # REPO_CLANG_FORMAT=".clang-format" + # ACTION_DEFAULT_CLANG_FORMAT="${{ github.action_path }}/.clang-format.default" + + # if [ -f "$REPO_CLANG_FORMAT" ]; then + # echo "::notice::Using repository's .clang-format file." + # # cat "$REPO_CLANG_FORMAT" + # else [ -f "$ACTION_DEFAULT_CLANG_FORMAT" ]; then + # echo "::notice::Repository does not have a .clang-format file. Using the action's default." + # cp "$ACTION_DEFAULT_CLANG_FORMAT" "$REPO_CLANG_FORMAT" + # fi + - name: Run clang-format check + id: check-format + shell: bash + run: | + REPO_CLANG_FORMAT=".clang-format" + ACTION_DEFAULT_CLANG_FORMAT="${{ github.action_path }}/.clang-format.default" + + if [ -f "$REPO_CLANG_FORMAT" ]; then + echo "::notice::Using repository's .clang-format file." + # pipx installs and runs clang-format with a specific version. + pipx run clang-format==${{ inputs.clang_format_version }} --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + else + echo "::notice::Repository does not have a .clang-format file. Using the action's default." + pipx run clang-format==${{ inputs.clang_format_version }} -style=file:$ACTION_DEFAULT_CLANG_FORMAT --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + fi From 8d1e3eaf165a3c2e63a3e68006dda5b82651756d Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Thu, 22 May 2025 17:39:17 +0000 Subject: [PATCH 2/4] Update README.md for clang-format composite action --- ci_clangformat/README.md | 29 ++++++++++++++++++++++ ci_clangformat/action.yaml | 49 +++----------------------------------- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/ci_clangformat/README.md b/ci_clangformat/README.md index eee712d..8fd5690 100644 --- a/ci_clangformat/README.md +++ b/ci_clangformat/README.md @@ -1 +1,30 @@ # CI Clang-format + +This composite action helps maintain consistent C/C++ code style by running +`clang-format` on modified files in your pull requests. It checks for +formatting violations and will cause the workflow to fail if any issues are +found, ensuring code quality before merging. + +The action uses your .clang-format style file if present in the repository +root; otherwise, it will use the .clang-format.default under this folder. + +This action offers the following configuration through its inputs: +* `clang_format_version`: Choose the exact clang-format version to use, + with `20.1.5` as the default to align with recent stable releases. +* `branch_name`: Specify the name of your repository branch (`main` by + default) for comparing changes within the pull requests. + +## Resolving Formatting Failures +If a workflow run fails due to formatting violations, you're expected to +fix the issues locally. Simply run `clang-format` on the problematic +files, e.g., using +`pipx run clang-format==20.1.5 --style=file --Werror -i `, +and then commit the formatted code to your pull request. + +## Pipx Requirement +This action leverages `pipx` to reliably install and run specific +`clang-format` versions, ensuring consistent behavior across different +environments. `pipx` is generally pre-installed on GitHub Actions hosted +runners (you can verify available tools on the runner images [doc](https://github.com/actions/runner-images?tab=readme-ov-file#available-images)). 
+If `pipx` does not exist, you'll need to include a step to install it +in your workflow's running environment. diff --git a/ci_clangformat/action.yaml b/ci_clangformat/action.yaml index 1d53c52..39647b9 100644 --- a/ci_clangformat/action.yaml +++ b/ci_clangformat/action.yaml @@ -15,19 +15,14 @@ name: "Check clang-format" description: 'Action to run clang-format on changed files' inputs: clang_format_version: - description: 'The clang-format version to use (e.g., "17", "18").' + description: 'The clang-format version to use.' required: true - default: "17.0.6" + default: "20.1.5" branch_name: - description: 'The name of the branch (e.g., "main") used for fetching comparisons.' + description: 'The repository branch used for fetching comparisons.' required: true default: 'main' -# outputs: -# changes_detected: -# description: 'True if clang-format detected formatting changes.' -# value: ${{ steps.check-format.outputs.changes_detected }} - runs: using: "composite" steps: @@ -36,44 +31,6 @@ runs: - name: "Fetch HEAD of ${{ inputs.branch_name }} branch" shell: bash run: git fetch origin ${{ inputs.branch_name }} --depth=1 - # - name: "Determine and install clang-format" - # shell: bash - # id: install-clang-format - # run: | - # CLANG_VERSION="${{ inputs.clang_format_version }}" - # CLANG_FORMAT_NAME="" - - # if [ "$CLANG_VERSION" == "latest" ]; then - # CLANG_FORMAT_NAME="clang-format" - # else - # CLANG_FORMAT_NAME="clang-format-${CLANG_VERSION}" - # fi - - # if command -v "$CLANG_FORMAT_NAME" &> /dev/null; then - # echo "::notice::'$CLANG_FORMAT_NAME' is already installed." - # else - # echo "::warning::'$CLANG_FORMAT_NAME' not found. Installing..." - # sudo apt-get update - # if ! sudo apt-get install -y "$CLANG_FORMAT_NAME"; then - # echo "::error::Failed to install '$CLANG_FORMAT_NAME'. Please check the version or try 'latest'." - # exit 1 - # fi - # fi - # echo "clang_format_exe=$CLANG_FORMAT_NAME" >> "$GITHUB_OUTPUT" - # - name: "Prepare .clang-format file" - # shell: bash - # id: prepare-config - # run: | - # REPO_CLANG_FORMAT=".clang-format" - # ACTION_DEFAULT_CLANG_FORMAT="${{ github.action_path }}/.clang-format.default" - - # if [ -f "$REPO_CLANG_FORMAT" ]; then - # echo "::notice::Using repository's .clang-format file." - # # cat "$REPO_CLANG_FORMAT" - # else [ -f "$ACTION_DEFAULT_CLANG_FORMAT" ]; then - # echo "::notice::Repository does not have a .clang-format file. Using the action's default." - # cp "$ACTION_DEFAULT_CLANG_FORMAT" "$REPO_CLANG_FORMAT" - # fi - name: Run clang-format check id: check-format shell: bash From 66eda9cc2148189c55403b75b01253674086ac97 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Thu, 22 May 2025 22:18:37 +0000 Subject: [PATCH 3/4] Replace pipx with uv as uv is pre-installed in our container images --- ci_clangformat/README.md | 13 ++++++------- ci_clangformat/action.yaml | 11 +++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ci_clangformat/README.md b/ci_clangformat/README.md index 8fd5690..9637e0d 100644 --- a/ci_clangformat/README.md +++ b/ci_clangformat/README.md @@ -18,13 +18,12 @@ This action offers the following configuration through its inputs: If a workflow run fails due to formatting violations, you're expected to fix the issues locally. Simply run `clang-format` on the problematic files, e.g., using -`pipx run clang-format==20.1.5 --style=file --Werror -i `, +`uvx clang-format==20.1.5 -i --verbose --style=file `, and then commit the formatted code to your pull request. 
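+
+To reformat every file your branch touches in one pass, you can reuse the
+same file selection that the action performs. A minimal local sketch,
+assuming `uv` is available and `main` is the branch you are comparing
+against (adjust the version to match `clang_format_version`):
+
+```sh
+# Reformat all changed C/C++ files on this branch in place.
+git diff --name-only origin/main HEAD -- '*.cc' '*.h' \
+  | xargs uvx clang-format==20.1.5 -i --verbose --style=file
+```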
-## Pipx Requirement -This action leverages `pipx` to reliably install and run specific +## UV Requirement +This action leverages `uv` to reliably install and run specific `clang-format` versions, ensuring consistent behavior across different -environments. `pipx` is generally pre-installed on GitHub Actions hosted -runners (you can verify available tools on the runner images [doc](https://github.com/actions/runner-images?tab=readme-ov-file#available-images)). -If `pipx` does not exist, you'll need to include a step to install it -in your workflow's running environment. +environments. `uvx` is a convenience alias that calls `uv tool run`. +If `uv` does not exist, you'll need to include a step to [install](https://docs.astral.sh/uv/getting-started/installation/) +it in your workflow's running environment. diff --git a/ci_clangformat/action.yaml b/ci_clangformat/action.yaml index 39647b9..f7baadd 100644 --- a/ci_clangformat/action.yaml +++ b/ci_clangformat/action.yaml @@ -30,7 +30,10 @@ runs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: "Fetch HEAD of ${{ inputs.branch_name }} branch" shell: bash - run: git fetch origin ${{ inputs.branch_name }} --depth=1 + run: | + # Silence safe.directory warnings when a job is running within a container + /usr/bin/git config --global --add safe.directory '*' + /usr/bin/git fetch origin ${{ inputs.branch_name }} --depth=1 - name: Run clang-format check id: check-format shell: bash @@ -40,9 +43,9 @@ runs: if [ -f "$REPO_CLANG_FORMAT" ]; then echo "::notice::Using repository's .clang-format file." - # pipx installs and runs clang-format with a specific version. - pipx run clang-format==${{ inputs.clang_format_version }} --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + # uvx (an alias for `uv tool run`) installs and runs clang-format with a specific version. + uvx clang-format==${{ inputs.clang_format_version }} --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') else echo "::notice::Repository does not have a .clang-format file. Using the action's default." 
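+        # `-style=file:<path>` (supported by clang-format 14 and newer) points
+        # clang-format at the action's bundled default config instead of
+        # searching the repository for a .clang-format file.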
- pipx run clang-format==${{ inputs.clang_format_version }} -style=file:$ACTION_DEFAULT_CLANG_FORMAT --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + uvx clang-format==${{ inputs.clang_format_version }} -style=file:$ACTION_DEFAULT_CLANG_FORMAT --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') fi From 20bf08e54851d46c55b56bc6581e290a77fef88c Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Tue, 27 May 2025 17:26:47 +0000 Subject: [PATCH 4/4] Update clang-format action and add a test --- .github/workflows/check-clang-format.yaml | 29 + ci_clangformat/README.md | 2 - ci_clangformat/action.yaml | 35 +- tests/ci_clangformat/absl_status_casters.h | 213 ++ tests/ci_clangformat/callback.cc | 173 ++ tests/ci_clangformat/callback.h | 87 + tests/ci_clangformat/config.cc | 348 +++ tests/ci_clangformat/config.h | 34 + tests/ci_clangformat/custom_call_sharding.cc | 346 +++ tests/ci_clangformat/custom_call_sharding.h | 28 + tests/ci_clangformat/dlpack.cc | 503 ++++ tests/ci_clangformat/dlpack.h | 58 + tests/ci_clangformat/dlpack_support.cc | 223 ++ tests/ci_clangformat/dlpack_support.h | 30 + tests/ci_clangformat/ffi.cc | 374 +++ tests/ci_clangformat/ffi.h | 152 ++ tests/ci_clangformat/ffi_helpers.h | 217 ++ tests/ci_clangformat/guard_lib.cc | 197 ++ tests/ci_clangformat/guard_lib.h | 115 + tests/ci_clangformat/ifrt_proxy.cc | 162 ++ tests/ci_clangformat/ifrt_proxy.h | 31 + tests/ci_clangformat/jax_jit.cc | 495 ++++ tests/ci_clangformat/jax_jit.h | 266 ++ tests/ci_clangformat/kernel_helpers.h | 50 + .../ci_clangformat/kernel_nanobind_helpers.h | 72 + tests/ci_clangformat/mlir.cc | 236 ++ tests/ci_clangformat/mlir.h | 28 + tests/ci_clangformat/nb_class_ptr.h | 59 + tests/ci_clangformat/pjit.cc | 1401 +++++++++++ tests/ci_clangformat/pjit.h | 27 + tests/ci_clangformat/pmap_lib.cc | 1141 +++++++++ tests/ci_clangformat/pmap_lib.h | 33 + tests/ci_clangformat/py_array.cc | 2137 +++++++++++++++++ tests/ci_clangformat/py_array.h | 362 +++ tests/ci_clangformat/py_client.cc | 1021 ++++++++ tests/ci_clangformat/py_client.h | 256 ++ tests/ci_clangformat/py_client_cpu.cc | 243 ++ tests/ci_clangformat/py_client_cpu.h | 28 + .../ci_clangformat/py_compile_only_client.cc | 145 ++ tests/ci_clangformat/py_compile_only_client.h | 45 + tests/ci_clangformat/py_device.cc | 350 +++ tests/ci_clangformat/py_device.h | 83 + tests/ci_clangformat/py_device_list.cc | 482 ++++ tests/ci_clangformat/py_device_list.h | 142 ++ tests/ci_clangformat/py_executable.cc | 427 ++++ tests/ci_clangformat/py_executable.h | 246 ++ tests/ci_clangformat/py_host_callback.cc | 259 ++ tests/ci_clangformat/py_host_callback.h | 119 + tests/ci_clangformat/py_memory_space.cc | 102 + tests/ci_clangformat/py_memory_space.h | 65 + tests/ci_clangformat/py_program.cc | 301 +++ tests/ci_clangformat/py_program.h | 27 + tests/ci_clangformat/py_socket_transfer.cc | 420 ++++ tests/ci_clangformat/py_socket_transfer.h | 26 + tests/ci_clangformat/py_values.cc | 1097 +++++++++ tests/ci_clangformat/py_values.h | 161 ++ tests/ci_clangformat/python_ref_manager.cc | 106 + tests/ci_clangformat/python_ref_manager.h | 108 + tests/ci_clangformat/pytree.cc | 1831 ++++++++++++++ tests/ci_clangformat/pytree.h | 408 ++++ tests/ci_clangformat/sdy.cc | 140 ++ tests/ci_clangformat/sdy.h | 28 + tests/ci_clangformat/sharded_device_array.h | 216 ++ tests/ci_clangformat/sharding.cc | 396 +++ tests/ci_clangformat/sharding.h | 241 ++ tests/ci_clangformat/to_ifrt_sharding.cc | 141 ++ 
tests/ci_clangformat/to_ifrt_sharding.h | 61 + tests/ci_clangformat/traceback.cc | 357 +++ tests/ci_clangformat/traceback.h | 109 + tests/ci_clangformat/util.cc | 85 + tests/ci_clangformat/util.h | 34 + tests/ci_clangformat/utils.cc | 300 +++ tests/ci_clangformat/weakref_lru_cache.cc | 416 ++++ tests/ci_clangformat/xla.cc | 984 ++++++++ tests/ci_clangformat/xla_compiler.cc | 1451 +++++++++++ tests/ci_clangformat/xla_compiler.h | 28 + 76 files changed, 23131 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/check-clang-format.yaml create mode 100644 tests/ci_clangformat/absl_status_casters.h create mode 100644 tests/ci_clangformat/callback.cc create mode 100644 tests/ci_clangformat/callback.h create mode 100644 tests/ci_clangformat/config.cc create mode 100644 tests/ci_clangformat/config.h create mode 100644 tests/ci_clangformat/custom_call_sharding.cc create mode 100644 tests/ci_clangformat/custom_call_sharding.h create mode 100644 tests/ci_clangformat/dlpack.cc create mode 100644 tests/ci_clangformat/dlpack.h create mode 100644 tests/ci_clangformat/dlpack_support.cc create mode 100644 tests/ci_clangformat/dlpack_support.h create mode 100644 tests/ci_clangformat/ffi.cc create mode 100644 tests/ci_clangformat/ffi.h create mode 100644 tests/ci_clangformat/ffi_helpers.h create mode 100644 tests/ci_clangformat/guard_lib.cc create mode 100644 tests/ci_clangformat/guard_lib.h create mode 100644 tests/ci_clangformat/ifrt_proxy.cc create mode 100644 tests/ci_clangformat/ifrt_proxy.h create mode 100644 tests/ci_clangformat/jax_jit.cc create mode 100644 tests/ci_clangformat/jax_jit.h create mode 100644 tests/ci_clangformat/kernel_helpers.h create mode 100644 tests/ci_clangformat/kernel_nanobind_helpers.h create mode 100644 tests/ci_clangformat/mlir.cc create mode 100644 tests/ci_clangformat/mlir.h create mode 100644 tests/ci_clangformat/nb_class_ptr.h create mode 100644 tests/ci_clangformat/pjit.cc create mode 100644 tests/ci_clangformat/pjit.h create mode 100644 tests/ci_clangformat/pmap_lib.cc create mode 100644 tests/ci_clangformat/pmap_lib.h create mode 100644 tests/ci_clangformat/py_array.cc create mode 100644 tests/ci_clangformat/py_array.h create mode 100644 tests/ci_clangformat/py_client.cc create mode 100644 tests/ci_clangformat/py_client.h create mode 100644 tests/ci_clangformat/py_client_cpu.cc create mode 100644 tests/ci_clangformat/py_client_cpu.h create mode 100644 tests/ci_clangformat/py_compile_only_client.cc create mode 100644 tests/ci_clangformat/py_compile_only_client.h create mode 100644 tests/ci_clangformat/py_device.cc create mode 100644 tests/ci_clangformat/py_device.h create mode 100644 tests/ci_clangformat/py_device_list.cc create mode 100644 tests/ci_clangformat/py_device_list.h create mode 100644 tests/ci_clangformat/py_executable.cc create mode 100644 tests/ci_clangformat/py_executable.h create mode 100644 tests/ci_clangformat/py_host_callback.cc create mode 100644 tests/ci_clangformat/py_host_callback.h create mode 100644 tests/ci_clangformat/py_memory_space.cc create mode 100644 tests/ci_clangformat/py_memory_space.h create mode 100644 tests/ci_clangformat/py_program.cc create mode 100644 tests/ci_clangformat/py_program.h create mode 100644 tests/ci_clangformat/py_socket_transfer.cc create mode 100644 tests/ci_clangformat/py_socket_transfer.h create mode 100644 tests/ci_clangformat/py_values.cc create mode 100644 tests/ci_clangformat/py_values.h create mode 100644 tests/ci_clangformat/python_ref_manager.cc create mode 100644 
tests/ci_clangformat/python_ref_manager.h create mode 100644 tests/ci_clangformat/pytree.cc create mode 100644 tests/ci_clangformat/pytree.h create mode 100644 tests/ci_clangformat/sdy.cc create mode 100644 tests/ci_clangformat/sdy.h create mode 100644 tests/ci_clangformat/sharded_device_array.h create mode 100644 tests/ci_clangformat/sharding.cc create mode 100644 tests/ci_clangformat/sharding.h create mode 100644 tests/ci_clangformat/to_ifrt_sharding.cc create mode 100644 tests/ci_clangformat/to_ifrt_sharding.h create mode 100644 tests/ci_clangformat/traceback.cc create mode 100644 tests/ci_clangformat/traceback.h create mode 100644 tests/ci_clangformat/util.cc create mode 100644 tests/ci_clangformat/util.h create mode 100644 tests/ci_clangformat/utils.cc create mode 100644 tests/ci_clangformat/weakref_lru_cache.cc create mode 100644 tests/ci_clangformat/xla.cc create mode 100644 tests/ci_clangformat/xla_compiler.cc create mode 100644 tests/ci_clangformat/xla_compiler.h diff --git a/.github/workflows/check-clang-format.yaml b/.github/workflows/check-clang-format.yaml new file mode 100644 index 0000000..e632167 --- /dev/null +++ b/.github/workflows/check-clang-format.yaml @@ -0,0 +1,29 @@ +# A workflow to test clang-format composite action + +name: Test CI Clang-format Action +# Run on pull_request +on: + pull_request: + paths: + - ci_clangformat/** + - tests/ci_clangformat/** + - .github/workflows/check-clang-format.yaml + branches: + - main +defaults: + run: + shell: bash +jobs: + run-clang-format: + runs-on: "linux-x86-n2-16" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + timeout-minutes: 10 + if: | + github.event.sender.type == 'User' || + contains(github.event.pull_request.body, 'RUN_CLANG_FORMAT') + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + - name: "Run clang-format action" + uses: ./ci_clangformat/ diff --git a/ci_clangformat/README.md b/ci_clangformat/README.md index 9637e0d..f02260d 100644 --- a/ci_clangformat/README.md +++ b/ci_clangformat/README.md @@ -11,8 +11,6 @@ root; otherwise, it will use the .clang-format.default under this folder. This action offers the following configuration through its inputs: * `clang_format_version`: Choose the exact clang-format version to use, with `20.1.5` as the default to align with recent stable releases. -* `branch_name`: Specify the name of your repository branch (`main` by - default) for comparing changes within the pull requests. ## Resolving Formatting Failures If a workflow run fails due to formatting violations, you're expected to diff --git a/ci_clangformat/action.yaml b/ci_clangformat/action.yaml index f7baadd..0c902e1 100644 --- a/ci_clangformat/action.yaml +++ b/ci_clangformat/action.yaml @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,34 +18,37 @@ inputs: description: 'The clang-format version to use.' required: true default: "20.1.5" - branch_name: - description: 'The repository branch used for fetching comparisons.' 
- required: true - default: 'main' runs: using: "composite" steps: - name: "Checking out repository" uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: "Fetch HEAD of ${{ inputs.branch_name }} branch" - shell: bash - run: | - # Silence safe.directory warnings when a job is running within a container - /usr/bin/git config --global --add safe.directory '*' - /usr/bin/git fetch origin ${{ inputs.branch_name }} --depth=1 + with: + fetch-depth: 0 # Fetch full history for accurate diffing - name: Run clang-format check id: check-format shell: bash + env: + REPO_CLANG_FORMAT: .clang-format + ACTION_DEFAULT_CLANG_FORMAT: ${{ github.action_path }}.clang-format.default + CLANG_FORMAT_VERSION: ${{ inputs.clang_format_version }} + TARGET_BRANCH: ${{ github.base_ref }} run: | - REPO_CLANG_FORMAT=".clang-format" - ACTION_DEFAULT_CLANG_FORMAT="${{ github.action_path }}/.clang-format.default" + FILE_PATTERNS="*.cc *.h" + CLANG_FORMAT_COMMON_ARGS="--dry-run --Werror --verbose" + + # Add this line to resolve the dubious ownership error + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + # Compare PR head against the base branch to find changed files. + GIT_DIFF_CMD="git diff -z --name-only --diff-filter=d origin/$TARGET_BRANCH HEAD -- $FILE_PATTERNS" if [ -f "$REPO_CLANG_FORMAT" ]; then echo "::notice::Using repository's .clang-format file." # uvx (an alias for `uv tool run`) installs and runs clang-format with a specific version. - uvx clang-format==${{ inputs.clang_format_version }} --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + $GIT_DIFF_CMD | xargs -0 uvx clang-format==$CLANG_FORMAT_VERSION $CLANG_FORMAT_COMMON_ARGS else - echo "::notice::Repository does not have a .clang-format file. Using the action's default." - uvx clang-format==${{ inputs.clang_format_version }} -style=file:$ACTION_DEFAULT_CLANG_FORMAT --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') + echo "::notice::Repository does not have a .clang-format file. Using the action's default under $ACTION_DEFAULT_CLANG_FORMAT." + $GIT_DIFF_CMD | xargs -0 uvx clang-format==$CLANG_FORMAT_VERSION -style=file:$ACTION_DEFAULT_CLANG_FORMAT $CLANG_FORMAT_COMMON_ARGS fi diff --git a/tests/ci_clangformat/absl_status_casters.h b/tests/ci_clangformat/absl_status_casters.h new file mode 100644 index 0000000..e3761f6 --- /dev/null +++ b/tests/ci_clangformat/absl_status_casters.h @@ -0,0 +1,213 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_ABSL_STATUS_CASTERS_H_ +#define JAXLIB_ABSL_STATUS_CASTERS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace jax { + +// C++ -> Python caster helpers. +// +// Failing statuses become Python exceptions; OK Status() becomes None. 
+// +// For example: +// +// - Functions without arguments: +// m.def("my_func", []() { ThrowIfError(MyFunc()); } +// - Classes with a single argument: +// py_class.def("delete", [](Buffer& self) { +// ThrowIfError(self.Delete()); +// } +// +// For functions with more arguments, you can either inline the arguments, +// or use the `ThrowIfErrorWrapper` wrapper defined below: +// +// m.def("my_func", ThrowIfErrorWrapper(MyFunc)); +// +// Nonstatic member functions can be wrapped by passing a +// pointer-to-member-function: +// ThrowIfErrorWrapper(&MyClass::MyMethod) + +inline void ThrowIfError(absl::Status src) { + if (!src.ok()) { + throw std::runtime_error(src.ToString()); + } +} + +// If one does not want to have to define a lambda specifying the inputs +// arguments, on can use the `ThrowIfErrorWrapper` wrapper. +// +// There are three specializations: +// - For free functions, `Sig` is the function type and `F` is `Sig&`. +// - For callable types, `Sig` is the pointer to member function type +// and `F` is the type of the callable. +// - For a nonstatic member function of a class `C`, `Sig` is the function type +// and `F` is Sig C::*. +// +// In the first two cases, the wrapper returns a callable with signature `Sig`; +// in the third case, the wrapper returns callable with a modified signature +// that takes a C instance as the first argument. +template +struct ThrowIfErrorWrapper; + +// C++17 "deduction guide" that guides class template argument deduction (CTAD) +// For free functions. +template +ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; + +// For callable types (with operator()). +template +ThrowIfErrorWrapper(absl::Status (&)(Args...)) + -> ThrowIfErrorWrapper; + +// For unbound nonstatic member functions. +template +ThrowIfErrorWrapper(absl::Status (C::*)(Args...)) + -> ThrowIfErrorWrapper; + +// Template specializations. + +// For free functions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (&f)(Args...)) : func(f) {} + void operator()(Args... args) { + ThrowIfError(func(std::forward(args)...)); + } + absl::Status (&func)(Args...); +}; + +// For callable types (with operator()), non-const and const versions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F &&f) : func(std::move(f)) {} + void operator()(Args... args) { + ThrowIfError(func(std::forward(args)...)); + } + F func; +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F &&f) : func(std::move(f)) {} + void operator()(Args... args) const { + ThrowIfError(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} + void operator()(C &instance, Args... args) { + ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...); +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + void operator()(const C &instance, Args... args) const { + ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...) const; +}; + +// Utilities for `StatusOr`. 
+template +T ValueOrThrow(absl::StatusOr v) { + if (!v.ok()) { + throw std::runtime_error(v.status().ToString()); + } + return std::move(v).value(); +} + +template +struct ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(absl::StatusOr (&)(Args...)) + -> ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)>; + +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...)) + -> ValueOrThrowWrapper(Args...), C>; + +// Deduction guide for const methods. +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...) const) + -> ValueOrThrowWrapper(Args...) const, C>; + +template +struct ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)> { + explicit ValueOrThrowWrapper(absl::StatusOr (&f)(Args...)) : func(f) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + absl::StatusOr (&func)(Args...); +}; +template +struct ValueOrThrowWrapper (C::*)(Args...), F> { + explicit ValueOrThrowWrapper(F &&f) : func(std::move(f)) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; +template +struct ValueOrThrowWrapper (C::*)(Args...) const, F> { + explicit ValueOrThrowWrapper(F &&f) : func(std::move(f)) {} + R operator()(Args... args) const { + return ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ValueOrThrowWrapper(Args...), C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...)) + : ptmf(ptmf) {} + R operator()(C &instance, Args... args) { + return ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...); +}; +template +struct ValueOrThrowWrapper(Args...) const, C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + R operator()(const C &instance, Args... args) const { + return ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...) const; +}; + +} // namespace jax + +#endif // JAXLIB_ABSL_STATUS_CASTERS_H_ diff --git a/tests/ci_clangformat/callback.cc b/tests/ci_clangformat/callback.cc new file mode 100644 index 0000000..4e0530c --- /dev/null +++ b/tests/ci_clangformat/callback.cc @@ -0,0 +1,173 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/callback.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/python_ref_manager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto &arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCall(void *result, void **arg_ptrs) { + absl::Span inputs(arg_ptrs, args_.size()); + absl::Span outputs(reinterpret_cast(result), + results_.size()); + + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); + for (size_t i = 0; i < inputs.size(); ++i) { + if (args_[i].type == xla::TOKEN) { + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); + } else { + nb_numpy_ndarray array = + nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); + } + } + + EnterHostCallback(); + absl::StatusOr maybe_result_tuple = Call(std::move(args)); + LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < results_.size(); ++i) { + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); + if (!plan.ok()) { + return std::move(plan).status(); + } + plan.value()->Execute(array.data(), outputs[i]); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr CpuCallback::Call(nb::tuple args) { + auto py_error_to_status = [](nb::python_error &e) { + std::string error_message = e.what(); + return absl::InternalError( + absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error &e) { + return py_error_to_status(e); + } + if (!PyTuple_Check(result_object.ptr())) { + return absl::InternalError( + absl::StrFormat("CPU 
callback expected a tuple result, got %s", + nb::cast(nb::repr(result_object)))); + } + if (PyTuple_Size(result_object.ptr()) != results_.size()) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple with %d results, got %d", + results_.size(), PyTuple_Size(result_object.ptr()))); + } + nb::tuple result_tuple = nb::cast(result_object); + for (size_t i = 0; i < results_.size(); ++i) { + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + if (results_[i].type == xla::TOKEN) { + if (!output.is_none()) { + return absl::InternalError(absl::StrFormat( + "Token output from Python callback should be None, got %s", + nb::cast(nb::repr(output)))); + } + continue; + } + nb_numpy_ndarray array; + try { + array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error &e) { + return py_error_to_status(e); + } + static_assert(sizeof(ssize_t) == sizeof(int64_t), + "Expected ssize_t to be of equal size to int64_t"); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + if (dims != results_[i].expected_dims) { + return absl::InternalError(absl::StrFormat( + "Mismatched result shape for %d-th return value from CPU callback; " + "expected array with dimensions %s, got %s", + i, absl::StrJoin(results_[i].expected_dims, ","), + absl::StrJoin(dims, ","))); + } + } + return result_tuple; +} + +} // namespace xla diff --git a/tests/ci_clangformat/callback.h b/tests/ci_clangformat/callback.h new file mode 100644 index 0000000..67104e1 --- /dev/null +++ b/tests/ci_clangformat/callback.h @@ -0,0 +1,87 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CALLBACK_H_ +#define JAXLIB_CALLBACK_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/transpose.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class CpuCallback { + public: + struct Arg { + xla::PrimitiveType type; // XLA type + nb_dtype dtype; // NumPy type, for array types. + absl::InlinedVector dims; // Dimensions, for array types. + std::vector strides; // Byte strides, for array types. + size_t size_in_bytes; // Size of the array in bytes. + }; + struct Result { + xla::PrimitiveType type; // XLA type + // Expected output shape, for array types + absl::InlinedVector expected_dims; + // Expected output byte strides, for array types. If the strides do not + // match the output will be transposed into the expected layout. + std::vector expected_strides; + // The desired order of output dimensions in major-to-minor order. + absl::InlinedVector reversed_layout; + // Size of the array in bytes. 
+ size_t size_in_bytes; + }; + + explicit CpuCallback(nanobind::callable callable, std::vector args, + std::vector results) + : callable_(std::move(callable)), + args_(std::move(args)), + results_(std::move(results)), + transpose_cache_(/*capacity=*/16) {} + + ~CpuCallback(); + + const std::vector &args() const { return args_; } + size_t num_args() const { return args_.size(); } + + const std::vector &results() const { return results_; } + size_t num_results() const { return results_.size(); } + void *callback() const { return callable_.ptr(); } + + xla::TransposePlanCache &transpose_cache() { return transpose_cache_; } + + absl::Status PrepareAndCall(void *result, void **arg_ptrs); + + absl::StatusOr Call(nanobind::tuple args); + + private: + nanobind::callable callable_; + std::vector args_; + std::vector results_; + xla::TransposePlanCache transpose_cache_; +}; + +} // namespace xla + +#endif // JAXLIB_CALLBACK_H_ diff --git a/tests/ci_clangformat/config.cc b/tests/ci_clangformat/config.cc new file mode 100644 index 0000000..ae525c9 --- /dev/null +++ b/tests/ci_clangformat/config.cc @@ -0,0 +1,348 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/config.h" + +#include + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/python_ref_manager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represet "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState &Instance() { + thread_local auto state = std::make_unique(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbarge + // collection. 
+ // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector entries_; +}; + +// This class represents all of the global configuration state. +// TODO(phawkins): to support free-threading, we will need to add locking to +// this class. +class GlobalConfigState { + public: + static GlobalConfigState &Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState *state) { + absl::MutexLock lock(&mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState *state) { + absl::MutexLock lock(&mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class. + int tp_traverse(int key, PyObject *self, visitproc visit, void *arg); + int tp_clear(int key, PyObject *self); + + // Returns the singleton object representing "value not set". + const nb::object &unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span include_in_jit_key() const { + return include_in_jit_key_; + } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbarge collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector entries_; + std::vector include_in_jit_key_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // any garbage collection. + GlobalConfigState::Instance().RemoveThreadLocalState(this); + // We do not hold the GIL, so we must use deferred destruction. + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject *self, visitproc visit, + void *arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject *value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(&mu_); + for (const auto *state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject *value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject *self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the python objects outside of the lock out of an abundance of + // caution. 
+ std::vector to_destroy; + absl::MutexLock lock(&mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto *state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(nb::object value, bool include_in_jit_key); + + // Returns the thread-local value if it is set, otherwise the global value. + nb::object Get(); + + // Returns the global value. + nb::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nb::object value); + + // Returns the thread-local value. + nb::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nb::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`. + nb::object SwapLocal(nb::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); + static PyType_Slot slots_[]; + + private: + int key_; +}; + +Config::Config(nb::object value, bool include_in_jit_key) { + auto &instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto &instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto &global_instance = GlobalConfigState::Instance(); + auto &instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + Config *c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. 
+ return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject *self) { + Config *c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +void BuildConfigSubmodule(nanobind::module_ &m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic()); + config.def(nb::init(), nb::arg("value").none(), + nb::arg("include_in_jit_key") = false); + config.def_prop_ro("value", &Config::Get); + config.def("get_local", &Config::GetLocal); + config.def("get_global", &Config::GetGlobal); + config.def("set_local", &Config::SetLocal, nb::arg("value").none()); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none()); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none()); +} + +std::vector JitConfigs() { + auto &instance = GlobalConfigState::Instance(); + auto &thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +} // namespace jax diff --git a/tests/ci_clangformat/config.h b/tests/ci_clangformat/config.h new file mode 100644 index 0000000..a4e39b1 --- /dev/null +++ b/tests/ci_clangformat/config.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CONFIG_H_ +#define JAXLIB_CONFIG_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +void BuildConfigSubmodule(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_CONFIG_H_ diff --git a/tests/ci_clangformat/custom_call_sharding.cc b/tests/ci_clangformat/custom_call_sharding.cc new file mode 100644 index 0000000..3928b03 --- /dev/null +++ b/tests/ci_clangformat/custom_call_sharding.cc @@ -0,0 +1,346 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/custom_call_sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks *self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks *self, + JAX_CustomCallPartitioner_Partition_Args *args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks *self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args *args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks *self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args *args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args *args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + 
nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error &e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error &e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args *args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error &e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args *args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. 
+ auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error &e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks *callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks *GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks *callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void *obj, JAX_InspectSharding_Callback_Args *args) { + std::optional arg = jax::InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error &e) { + jax::InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_ &m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto *c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto *c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension *extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error *error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. 
This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const HloSharding &sharding, std::vector dims) { + return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto *c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension *extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error *error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace xla diff --git a/tests/ci_clangformat/custom_call_sharding.h b/tests/ci_clangformat/custom_call_sharding.h new file mode 100644 index 0000000..ab9711d --- /dev/null +++ b/tests/ci_clangformat/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildCustomCallShardingPybindAPI(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_CUSTOM_CALL_SHARDING_H_ diff --git a/tests/ci_clangformat/dlpack.cc b/tests/ci_clangformat/dlpack.cc new file mode 100644 index 0000000..34f7596 --- /dev/null +++ b/tests/ci_clangformat/dlpack.cc @@ -0,0 +1,503 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "jaxlib/dlpack_support.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "llvm/Support/Casting.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +const char *const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. 
+ std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor *t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice &device) { + if (device.client()->platform_id() == CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == RocmId()) { + return kDLROCM; + } + return InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr DLDeviceForDevice(const PjRtDevice &device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::StatusOr DeviceForDLDevice(const PjRtClient *cpu_client, + const PjRtClient *gpu_client, + const DLDevice &context) { + switch (context.device_type) { + case kDLCPU: + if (cpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on CPU, but no CPU backend was provided."); + } + TF_RET_CHECK(cpu_client->platform_id() == CpuId()); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLCUDA: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == CudaId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLROCM: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == RocmId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + default: + return InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::Status VerifyDType(const DLTensor &dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } 
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::vector<int64_t>> GetByteStrides(const DLTensor &dl_tensor) {
+  TF_RETURN_IF_ERROR(VerifyDType(dl_tensor));
+
+  // Convert element strides from the number of elements to the number of bytes.
+  std::vector<int64_t> strides;
+  strides.reserve(dl_tensor.ndim);
+  for (int i = 0; i < dl_tensor.ndim; ++i) {
+    strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8);
+  }
+  return strides;
+}
+
+absl::StatusOr<std::unique_ptr<PjRtBuffer>> MakePjrtBuffer(
+    PjRtDevice &device, ::DLManagedTensor *dlmt, const Shape &shape,
+    PrimitiveType element_type, absl::Span<const int64_t> dimensions,
+    std::optional<std::intptr_t> stream = std::nullopt) {
+  std::function<void()> on_delete_callback;
+  if (dlmt->deleter) {
+    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
+  }
+
+  // First try to create a view.
+  void *data =
+      static_cast<char *>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset;
+  auto result = device.client()->CreateViewOfDeviceBuffer(
+      data, shape, *device.default_memory_space(), on_delete_callback, stream);
+
+  // If that fails with an invalid-argument error, it is possibly because of
+  // incorrect alignment. If we're on CPU, we can create a copy of the buffer.
+  if (result.status().code() == absl::StatusCode::kInvalidArgument &&
+      dlmt->dl_tensor.device.device_type == kDLCPU) {
+    LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data
+                 << "). Creating a copy.";
+
+    // Convert tensor strides (expressed in number of elements) to byte strides.
+    std::optional<std::vector<int64_t>> byte_strides;
+    if (dlmt->dl_tensor.strides) {
+      TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor));
+    }
+
+    TF_ASSIGN_OR_RETURN(auto *memory_space, device.default_memory_space());
+
+    // Create a copy.
+    result = device.client()->BufferFromHostBuffer(
+        data, element_type, dimensions, byte_strides,
+        PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback,
+        memory_space, /*device_layout=*/nullptr);
+  }
+  return result;
+}
+
+} // namespace
+
+absl::StatusOr<nb::capsule> BufferToDLPackManagedTensor(
+    nb::handle py_buffer, std::optional<std::intptr_t> stream) {
+  ifrt::Array *ifrt_array = nb::cast<PyArray>(py_buffer).ifrt_array();
+  if (ifrt_array == nullptr) {
+    return Unimplemented(
+        "BufferToDLPackManagedTensor called on deleted array.");
+  }
+  auto *arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
+  if (arr == nullptr) {
+    throw XlaRuntimeError(
+        "This operation is implemented for a PjRt-compatible backend only.");
+  }
+  PjRtBuffer *pjrt_buffer = arr->pjrt_buffers().front().get();
+
+  if (pjrt_buffer->IsTuple()) {
+    return Unimplemented(
+        "BufferToDLPackManagedTensor is not implemented for tuple "
+        "buffers.");
+  }
+  if (pjrt_buffer->has_dynamic_dimensions()) {
+    return Unimplemented("DynamicShape is not implemented in DLPack.");
+  }
+
+  auto pack = std::make_unique<DLPackTensor>();
+  DLTensor &dt = pack->tensor.dl_tensor;
+  {
+    // AcquireExternalReference may block; there are no API guarantees.
+    GlobalPyRefManager()->CollectGarbage();
+    nb::gil_scoped_release gil_release;
+    TF_ASSIGN_OR_RETURN(pack->external_reference,
+                        pjrt_buffer->AcquireExternalReference());
+    if (stream) {
+      TF_RETURN_IF_ERROR(
+          pack->external_reference->WaitUntilBufferReadyOnStream(*stream));
+    } else {
+      TF_RETURN_IF_ERROR(
+          AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)));
+    }
+  }
+  pack->buffer_reference = nb::borrow<nb::object>(py_buffer);
+
+  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
+  pack->tensor.manager_ctx = pack.get();
+  pack->tensor.deleter = DLPackTensorDeleter;
+  TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device()));
+  dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value();
+  dt.ndim = pjrt_buffer->dimensions().size();
+  TF_ASSIGN_OR_RETURN(dt.dtype,
+                      PrimitiveTypeToDLDataType(pjrt_buffer->element_type()));
+
+  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(),
+                                     pjrt_buffer->dimensions().end());
+
+  // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout
+  Layout xla_layout = pjrt_buffer->layout()->xla_layout();
+  pack->strides = StridesForShape(pjrt_buffer->element_type(),
+                                  pjrt_buffer->dimensions(), xla_layout);
+
+  dt.shape = reinterpret_cast<std::int64_t *>(pack->shape.data());
+  dt.strides = reinterpret_cast<std::int64_t *>(pack->strides.data());
+  dt.byte_offset = 0;
+
+  // We cannot use nanobind's capsule object constructor because we need to
+  // detect if the capsule name has been changed in the deleter, but nanobind
+  // hides the underlying Python object from the deleter.
+  nb::capsule capsule = nb::steal<nb::capsule>(
+      PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName,
+                    [](PyObject *obj) noexcept {
+                      DLManagedTensor *dlmt = static_cast<DLManagedTensor *>(
+                          PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
+                      if (dlmt) {
+                        DLPackTensorDeleter(dlmt);
+                      } else {
+                        // The tensor has been deleted. Clear any error from
+                        // PyCapsule_GetPointer.
+                        PyErr_Clear();
+                      }
+                    }));
+  if (!capsule.ptr()) {
+    throw nb::python_error();
+  }
+  return capsule;
+}
+
+absl::StatusOr<PyArray> DLPackManagedTensorToBuffer(
+    const nb::capsule &tensor,
+    std::optional<nb_class_ptr<PyClient>> cpu_client,
+    std::optional<nb_class_ptr<PyClient>> gpu_client) {
+  // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex
+  // multiple PjRt clients. Devices from these PjRt clients could be expressed
+  // as a unified set of IFRT devices.
+  auto *cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr;
+  auto *gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr;
+
+  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
+    return InvalidArgument(
+        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
+        "Note that a DLPack tensor may be consumed at most once.",
+        absl::string_view(tensor.name()));
+  }
+  DLManagedTensor *dlmt = static_cast<DLManagedTensor *>(tensor.data());
+  if (dlmt->dl_tensor.ndim < 0) {
+    return InvalidArgument(
+        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
+        dlmt->dl_tensor.ndim);
+  }
+  TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                      DeviceForDLDevice(cpu_client ? cpu_pjrt_client : nullptr,
+                                        gpu_client ?
gpu_pjrt_client : nullptr,
+                                        dlmt->dl_tensor.device));
+  absl::Span<int64_t const> dimensions(
+      reinterpret_cast<int64_t *>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
+  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
+                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));
+
+  std::vector<int64_t> minor_to_major;
+  if (dlmt->dl_tensor.strides &&
+      absl::c_find(dimensions, 0) == dimensions.end()) {
+    absl::Span<int64_t const> strides(
+        reinterpret_cast<int64_t *>(dlmt->dl_tensor.strides),
+        dlmt->dl_tensor.ndim);
+    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
+  } else {
+    minor_to_major.resize(dlmt->dl_tensor.ndim);
+    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
+  }
+  Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions,
+                                                    minor_to_major);
+
+  // Raise an error if the resulting PjRtBuffer would have a non-default layout.
+  // TODO(skyewm): we do this because JAX doesn't currently have good support
+  // for non-default layouts, and will return wrong results if a non-default
+  // layout is passed to a computation expecting default layouts. Remove this
+  // special case when non-default layouts are better supported by JAX.
+  TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout(
+                                                 element_type, dimensions));
+  if (shape.layout() != default_layout) {
+    return Unimplemented(
+        "from_dlpack got array with non-default layout with minor-to-major "
+        "dimensions (%s), expected (%s)",
+        absl::StrJoin(shape.layout().minor_to_major(), ","),
+        absl::StrJoin(default_layout.minor_to_major(), ","));
+  }
+
+  std::function<void()> on_delete_callback;
+  if (dlmt->deleter) {
+    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      auto pjrt_buffer,
+      MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions));
+
+  // We have taken ownership of the array inside the capsule; make sure the
+  // capsule cannot be used again.
+  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
+  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
+  // TODO(phawkins): simplify the expression below once we know cpu_client is
+  // always non-null.
+  auto client = (cpu_client && device->client() == cpu_pjrt_client)
+                    ? std::move(*cpu_client)
+                    : std::move(*gpu_client);
+  auto *ifrt_client =
+      llvm::dyn_cast_or_null<ifrt::PjRtClient>(client->ifrt_client());
+  if (ifrt_client == nullptr) {
+    throw XlaRuntimeError(
+        "This operation is implemented for a PjRt-compatible backend only.");
+  }
+  TF_ASSIGN_OR_RETURN(auto ifrt_array,
+                      ifrt_client->CreatePjRtArray(std::move(pjrt_buffer)));
+  return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(),
+                                            std::move(ifrt_array), false, true);
+}
+
+absl::StatusOr<PyArray> DLPackManagedTensorToBuffer(
+    const nb::capsule &tensor, ifrt::Device *ifrt_device,
+    nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream) {
+  ifrt::PjRtDevice *device =
+      llvm::dyn_cast_or_null<ifrt::PjRtDevice>(ifrt_device);
+  if (device == nullptr) {
+    throw XlaRuntimeError(
+        "DLPack is supported for PjRt-compatible backends only.");
+  }
+  if (!device->IsAddressable()) {
+    throw XlaRuntimeError(
+        "DLPack is only supported for devices addressable by the current "
+        "process.");
+  }
+  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
+    return InvalidArgument(
+        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". 
" + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor *dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto *ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace xla diff --git a/tests/ci_clangformat/dlpack.h b/tests/ci_clangformat/dlpack.h new file mode 100644 index 0000000..7659d66 --- /dev/null +++ b/tests/ci_clangformat/dlpack.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_DLPACK_H_ +#define JAXLIB_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. 
cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule &tensor, + std::optional> cpu_client, + std::optional> gpu_client); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule &tensor, ifrt::Device *device, + nb_class_ptr client, std::optional stream); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type); + +} // namespace xla + +#endif // JAXLIB_DLPACK_H_ diff --git a/tests/ci_clangformat/dlpack_support.cc b/tests/ci_clangformat/dlpack_support.cc new file mode 100644 index 0000000..9e85184 --- /dev/null +++ b/tests/ci_clangformat/dlpack_support.cc @@ -0,0 +1,223 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/dlpack_support.h" + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch 
(type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: + if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +} // namespace xla diff --git a/tests/ci_clangformat/dlpack_support.h b/tests/ci_clangformat/dlpack_support.h new file mode 100644 index 0000000..25e8623 --- /dev/null +++ b/tests/ci_clangformat/dlpack_support.h @@ -0,0 +1,30 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_DLPACK_SUPPORT_H_ +#define JAXLIB_XLA_DLPACK_SUPPORT_H_ + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type); +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type); + +} // namespace xla + +#endif // JAXLIB_XLA_DLPACK_SUPPORT_H_ diff --git a/tests/ci_clangformat/ffi.cc b/tests/ci_clangformat/ffi.cc new file mode 100644 index 0000000..5e0778c --- /dev/null +++ b/tests/ci_clangformat/ffi.cc @@ -0,0 +1,374 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/ffi.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "jaxlib/dlpack_support.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +namespace { +const char *const kDlTensorCapsuleName = "dltensor"; +const char *const kDlTensorVersionedCapsuleName = "dltensor_versioned"; + +template +struct DLPackTensor { + std::vector shape; + ManagedTensor tensor; +}; + +template +void DLPackTensorDeleter(ManagedTensor *t) { + if (t) { + delete static_cast *>(t->manager_ctx); + } +} + +xla::PrimitiveType PrimitiveTypeForFfiDataType(ffi::DataType dtype) { + switch (dtype) { + case ffi::DataType::INVALID: + return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; + case ffi::PRED: + return xla::PrimitiveType::PRED; + case ffi::S1: + return xla::PrimitiveType::S1; + case ffi::S2: + return xla::PrimitiveType::S2; + case ffi::S4: + return xla::PrimitiveType::S4; + case ffi::S8: + return xla::PrimitiveType::S8; + case ffi::S16: + return xla::PrimitiveType::S16; + case ffi::S32: + return xla::PrimitiveType::S32; + case ffi::S64: + return xla::PrimitiveType::S64; + case ffi::U1: + return xla::PrimitiveType::U1; + case ffi::U2: + return xla::PrimitiveType::U2; + case ffi::U4: + return xla::PrimitiveType::U4; + case ffi::U8: + return xla::PrimitiveType::U8; + case ffi::U16: + return xla::PrimitiveType::U16; + case ffi::U32: + 
return xla::PrimitiveType::U32; + case ffi::U64: + return xla::PrimitiveType::U64; + case ffi::F16: + return xla::PrimitiveType::F16; + case ffi::F32: + return xla::PrimitiveType::F32; + case ffi::F64: + return xla::PrimitiveType::F64; + case ffi::BF16: + return xla::PrimitiveType::BF16; + case ffi::C64: + return xla::PrimitiveType::C64; + case ffi::C128: + return xla::PrimitiveType::C128; + case ffi::TOKEN: + return xla::PrimitiveType::TOKEN; + case ffi::F8E5M2: + return xla::PrimitiveType::F8E5M2; + case ffi::F8E4M3: + return xla::PrimitiveType::F8E4M3; + case ffi::F8E4M3FN: + return xla::PrimitiveType::F8E4M3FN; + case ffi::F8E4M3B11FNUZ: + return xla::PrimitiveType::F8E4M3B11FNUZ; + case ffi::F8E5M2FNUZ: + return xla::PrimitiveType::F8E5M2FNUZ; + case ffi::F8E4M3FNUZ: + return xla::PrimitiveType::F8E4M3FNUZ; + case ffi::F8E3M4: + return xla::PrimitiveType::F8E3M4; + case ffi::F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; + case ffi::F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; + } +} +} // namespace + +PyFfiContext::PyFfiContext(const XLA_FFI_Api *api, + XLA_FFI_ExecutionContext *ctx, + XLA_FFI_ExecutionStage stage) + : api_(api), ctx_(ctx), stage_(stage) {} + +PyFfiContext::Stage PyFfiContext::stage() const { + return static_cast(stage_); +} + +absl::StatusOr PyFfiContext::stream() const { + XLA_FFI_Stream_Get_Args args; + args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.ctx = ctx_; + args.stream = nullptr; + if (XLA_FFI_Error *error = api_->XLA_FFI_Stream_Get(&args)) { + return ffi::TakeStatus(error); + } + return absl::bit_cast(args.stream); +} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + void *data, ffi::Span dimensions, + ffi::DataType element_type, bool writeable) + : device_type_(device_type), + device_ordinal_(device_ordinal), + data_(data), + dimensions_(dimensions.begin(), dimensions.size()), + element_type_(PrimitiveTypeForFfiDataType(element_type)), + writeable_(writeable) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf.untyped_data(), + buf.dimensions(), buf.element_type(), + /*writeable=*/false) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf->untyped_data(), + buf->dimensions(), buf->element_type(), + /*writeable=*/true) {} + +absl::StatusOr PyFfiAnyBuffer::dtype() const { + return xla::PrimitiveTypeToNbDtype(element_type_); +} + +size_t PyFfiAnyBuffer::ndim() const { return dimensions_.size(); } + +nb::tuple PyFfiAnyBuffer::shape() const { + return xla::SpanToNbTuple(dimensions_); +} + +bool PyFfiAnyBuffer::writeable() const { return writeable_; } + +absl::StatusOr PyFfiAnyBuffer::NumpyArray() const { + if (device_type_ != kDLCPU) { + return absl::UnimplementedError( + "Buffer.__array__ is only supported on CPU."); + } + + TF_ASSIGN_OR_RETURN(auto dtype, this->dtype()); + xla::nb_numpy_ndarray array(dtype, dimensions_, /* strides= */ std::nullopt, + data_, nb::cast(this)); + + // TODO(danfm): We don't seem to be allowed to set this flag like this + // because the array doesn't own its data. 
+ // array.attr("flags").attr("writeable") = nb::bool_(writeable_); + + return array; +} + +absl::StatusOr PyFfiAnyBuffer::CudaArrayInterface() const { + if (device_type_ != kDLCUDA) { + return absl::UnimplementedError( + "Buffer.__cuda_array_interface__ is only supported on CUDA."); + } + + nb::dict result; + result["shape"] = xla::SpanToNbTuple(dimensions_); + TF_ASSIGN_OR_RETURN(result["typestr"], + TypeDescriptorForPrimitiveType(element_type_)); + result["data"] = nb::make_tuple( + nb::int_(absl::bit_cast(data_)), !writeable_); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr PyFfiAnyBuffer::DLPack() const { + auto pack = std::make_unique>(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + + DLTensor &dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject *obj) noexcept { + DLManagedTensor *dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +absl::StatusOr PyFfiAnyBuffer::DLPackVersioned() const { + auto pack = std::make_unique>(); + pack->tensor.version = + DLPackVersion{DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION}; + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + pack->tensor.flags = writeable_ ? 0 : DLPACK_FLAG_BITMASK_READ_ONLY; + + DLTensor &dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal(PyCapsule_New( + &pack.release()->tensor, kDlTensorVersionedCapsuleName, + [](PyObject *obj) noexcept { + DLManagedTensorVersioned *dlmt = + static_cast( + PyCapsule_GetPointer(obj, kDlTensorVersionedCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. 
+ PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +nb::tuple PyFfiAnyBuffer::DLPackDevice() const { + return nb::make_tuple(static_cast(device_type_), device_ordinal_); +} + +void BuildFfiSubmodule(nb::module_ &m) { + tsl::ImportNumpy(); + + nb::module_ ffi_module = + m.def_submodule("ffi", "Python bindings for the XLA FFI."); + + nb::class_ buffer(ffi_module, "Buffer"); + buffer.def_prop_ro("dtype", xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::dtype)); + buffer.def_prop_ro("ndim", &PyFfiAnyBuffer::ndim); + buffer.def_prop_ro("shape", &PyFfiAnyBuffer::shape); + buffer.def_prop_ro("writeable", &PyFfiAnyBuffer::writeable); + buffer.def( + "__array__", + [](PyFfiAnyBuffer self, nb::object dtype, nb::object copy) { + if (!dtype.is_none()) { + throw nb::value_error( + "dtype parameter is not supported by Buffer.__array__."); + } + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__array__ with copy=True is not supported."); + } + return xla::ValueOrThrow(self.NumpyArray()); + }, + nb::arg("dtype") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def_prop_ro( + "__cuda_array_interface__", + xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::CudaArrayInterface)); + buffer.def( + "__dlpack__", + [](PyFfiAnyBuffer self, nb::object stream, nb::object max_version, + nb::object dl_device, nb::object copy) { + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__dlpack__ with copy=True is not supported."); + } + + // Fall back on the non-versioned API if unsupported by the requested + // max_version. + nb::tuple max_version_tuple; + int64_t max_version_major; + if (!nb::try_cast(max_version, max_version_tuple) || + max_version_tuple.size() < 2 || + !nb::try_cast(max_version_tuple[0], max_version_major) || + max_version_major < 1) { + return xla::ValueOrThrow(self.DLPack()); + } + + // TODO(danfm): Handle other optional inputs. + return xla::ValueOrThrow(self.DLPackVersioned()); + }, + nb::arg("stream") = nb::none(), nb::arg("max_version") = nb::none(), + nb::arg("dl_device") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def("__dlpack_device__", &PyFfiAnyBuffer::DLPackDevice); + + nb::enum_(ffi_module, "ExecutionStage") + .value("INSTANTIATE", PyFfiContext::Stage::kInstantiate) + .value("PREPARE", PyFfiContext::Stage::kPrepare) + .value("INITIALIZE", PyFfiContext::Stage::kInitialize) + .value("EXECUTE", PyFfiContext::Stage::kExecute) + .export_values(); + + nb::class_ context(ffi_module, "ExecutionContext"); + context.def_prop_ro("stage", &PyFfiContext::stage); + context.def_prop_ro("stream", + xla::ValueOrThrowWrapper(&PyFfiContext::stream)); +} + +} // namespace jax diff --git a/tests/ci_clangformat/ffi.h b/tests/ci_clangformat/ffi.h new file mode 100644 index 0000000..ff6ee13 --- /dev/null +++ b/tests/ci_clangformat/ffi.h @@ -0,0 +1,152 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_FFI_H_ +#define JAXLIB_XLA_FFI_H_ + +#include + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// Wrapper class for XLA FFI execution context. +// +// This class provides a Python interface to the XLA FFI execution context, +// exposing metadata such as the execution stage, device ordinal, and stream. +class PyFfiContext { + public: + enum class Stage { + kInstantiate, + kPrepare, + kInitialize, + kExecute, + }; + + PyFfiContext(const XLA_FFI_Api *api, XLA_FFI_ExecutionContext *ctx, + XLA_FFI_ExecutionStage stage); + Stage stage() const; + absl::StatusOr stream() const; + + private: + const XLA_FFI_Api *api_; + XLA_FFI_ExecutionContext *ctx_; + XLA_FFI_ExecutionStage stage_; +}; + +// Wrapper class for XLA FFI AnyBuffer. +// +// This class provides a Python interface to the XLA FFI `AnyBuffer` class. +// From Python, this object looks like an array (with `.dtype` and `.shape` +// attributes), but it also provides methods zero-copy conversions to standard +// transport formats: `__array__`, `__cuda_array_interface__`, and `__dlpack__`. +class PyFfiAnyBuffer { + public: + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, void *data, + ffi::Span dimensions, + ffi::DataType element_type, bool writeable); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf); + + absl::StatusOr dtype() const; + size_t ndim() const; + nb::tuple shape() const; + bool writeable() const; + + absl::StatusOr NumpyArray() const; + absl::StatusOr CudaArrayInterface() const; + absl::StatusOr DLPack() const; + absl::StatusOr DLPackVersioned() const; + nb::tuple DLPackDevice() const; + + private: + DLDeviceType device_type_; + int32_t device_ordinal_; + void *data_; + absl::Span dimensions_; + xla::PrimitiveType element_type_; + bool writeable_; +}; + +template +ffi::Error XlaBufferCallback(int32_t device_ordinal, const XLA_FFI_Api *api, + XLA_FFI_ExecutionContext *ctx, + xla::FfiLoadedHostCallbacks *callbacks, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = + nb::steal(PyTuple_New(1 + args.size() + rets.size())); + + jax::PyFfiContext py_ctx(api, ctx, XLA_FFI_ExecutionStage_EXECUTE); + PyTuple_SET_ITEM(nb_args.ptr(), 0, nb::cast(py_ctx).release().ptr()); + + size_t offset = 1; + for (size_t i = 0; i < args.size(); ++i, ++offset) { + auto arg = args.get(i); + if (arg.has_error()) { + return arg.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, arg.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + for (size_t i = 0; i < rets.size(); ++i, ++offset) { + auto ret = rets.get(i); + if (ret.has_error()) { + return ret.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, ret.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + 
xla::EnterHostCallback(); + try { + callback(*nb::borrow(nb_args)); + } catch (nb::python_error &e) { + return ffi::Error::Internal( + absl::StrFormat("Error when calling buffer callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + return ffi::Error::Success(); +} + +void BuildFfiSubmodule(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_XLA_FFI_H_ diff --git a/tests/ci_clangformat/ffi_helpers.h b/tests/ci_clangformat/ffi_helpers.h new file mode 100644 index 0000000..3492276 --- /dev/null +++ b/tests/ci_clangformat/ffi_helpers.h @@ -0,0 +1,217 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_FFI_HELPERS_H_ +#define JAXLIB_FFI_HELPERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { + +// Returns from the function if the argument is an ffi::Error. +#define FFI_RETURN_IF_ERROR(...) \ + do { \ + ::xla::ffi::Error err = (__VA_ARGS__); \ + if (ABSL_PREDICT_FALSE(err.failure())) { \ + return err; \ + } \ + } while (0) + +// Returns from the function with an ffi::Error if the argument is an +// absl::Status. +#define FFI_RETURN_IF_ERROR_STATUS(...) \ + do { \ + ::absl::Status status = (__VA_ARGS__); \ + if (ABSL_PREDICT_FALSE(!status.ok())) { \ + return ::jax::AsFfiError(status); \ + } \ + } while (0) + +// Returns from the function with an ffi::Error if the RHS is an absl::Status, +// otherwise assigns to the LHS. Most of the complication here stems from the +// fact that we want to support having the LHS wrapped in parentheses (when +// unpacking a tuple, for example). +#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ + FFI_ASSIGN_OR_RETURN_IMPL_( \ + FFI_ASSIGN_OR_RETURN_CONCAT_(_status_or_value, __LINE__), lhs, rhs) + +#define FFI_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rhs) \ + auto statusor = (rhs); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return ::jax::AsFfiError(statusor.status()); \ + } \ + FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + (*std::move(statusor)) + +#define FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) x##y +#define FFI_ASSIGN_OR_RETURN_CONCAT_(x, y) \ + FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) + +// All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN +// where the LHS is wrapped in parentheses. See a more detailed discussion at +// https://stackoverflow.com/a/62984543 +#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ + FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) +#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) 
\ + FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) +#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY + +template +inline absl::StatusOr MaybeCastNoOverflow( + std::int64_t value, std::string_view source = __FILE__) { + if constexpr (sizeof(T) == sizeof(std::int64_t)) { + return value; + } else { + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Value (=%d) exceeds the maximum representable value of the " + "desired type", + source, value)); + } + return static_cast(value); + } +} + +inline ::xla::ffi::Error AsFfiError(const absl::Status &status) { + if (ABSL_PREDICT_FALSE(!status.ok())) { + return ::xla::ffi::Error(static_cast(status.code()), + std::string(status.message())); + } else { + return ::xla::ffi::Error::Success(); + } +} + +inline int64_t GetBatchSize(::xla::ffi::Span dims) { + return absl::c_accumulate(dims, 1, std::multiplies()); +} + +inline absl::StatusOr> SplitBatch1D( + ::xla::ffi::Span dims, + const std::string &source = __FILE__) { + if (dims.size() < 1) { + return absl::InvalidArgumentError( + absl::StrFormat("%s: Argument must have at least 1 dimension", source)); + } + return std::make_pair(GetBatchSize(dims.first(dims.size() - 1)), dims.back()); +} + +inline absl::StatusOr> SplitBatch2D( + ::xla::ffi::Span dims, + const std::string &source = __FILE__) { + if (dims.size() < 2) { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Argument must have at least 2 dimensions", source)); + } + auto trailingDims = dims.last(2); + return std::make_tuple(GetBatchSize(dims.first(dims.size() - 2)), + trailingDims.front(), trailingDims.back()); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + int64_t expected_batch, + std::string_view name, + std::string_view op) { + auto batch = GetBatchSize(dimensions); + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, size]), SplitBatch1D(dimensions)); + auto [expected_batch, expected_size] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (batch != expected_batch || size != expected_size) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid trailing dimension for input %s " + "to %s. Expected %d, got %d.", + name, op, expected_size, size)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(dimensions)); + auto [expected_batch, expected_rows, expected_cols] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. 
Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (rows != expected_rows || cols != expected_cols) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid matrix dimensions for input %s to %s. " + "Expected (%d, %d), got (%d, %d).", + name, op, expected_rows, expected_cols, rows, cols)); + } + return ::xla::ffi::Error::Success(); +} + +template <::xla::ffi::DataType dtype> +auto AllocateScratchMemory(std::size_t size) + -> std::unique_ptr>[]> { + // TODO(paruzelp): use std::make_unique_for_overwrite when C++20 is available. + using ValueType = std::remove_extent_t<::xla::ffi::NativeType>; + return std::unique_ptr(new ValueType[size]); +} + +template +inline absl::StatusOr AllocateWorkspace( + ::xla::ffi::ScratchAllocator &scratch, int64_t size, + std::string_view name) { + auto maybe_workspace = scratch.Allocate(sizeof(T) * size); + if (!maybe_workspace.has_value()) { + return absl::Status( + absl::StatusCode::kResourceExhausted, + absl::StrFormat("Unable to allocate workspace for %s", name)); + } + return static_cast(maybe_workspace.value()); +} + +} // namespace jax + +#endif // JAXLIB_FFI_HELPERS_H_ diff --git a/tests/ci_clangformat/guard_lib.cc b/tests/ci_clangformat/guard_lib.cc new file mode 100644 index 0000000..71c34df --- /dev/null +++ b/tests/ci_clangformat/guard_lib.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. + +#include "jaxlib/guard_lib.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +namespace { + +// Protected by the GIL. +GuardState &global_state = *new GuardState(); + +ABSL_CONST_INIT thread_local GuardState thread_local_state; + +// The default transfer guard level. +constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; + +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + +// Returns the transfer guard action for a transfer. 
+TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, + bool explicit_transfer) { + switch (guard_level) { + case TransferGuardLevel::kAllow: + return TransferGuardAction::kAllow; + case TransferGuardLevel::kLog: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kLog; + } + case TransferGuardLevel::kDisallow: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kDisallow; + } + case TransferGuardLevel::kLogExplicit: + return TransferGuardAction::kLog; + case TransferGuardLevel::kDisallowExplicit: + return TransferGuardAction::kDisallow; + default: + // Unreachable; gracefully handle the unexpected guard level and prevent a + // compiler warning. + return TransferGuardAction::kDisallow; + } +} + +// Returns the transfer guard action for a host-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForHostToDevice() { + return GetTransferGuardAction( + thread_local_state.host_to_device.value_or( + global_state.host_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToDevice() { + return GetTransferGuardAction( + thread_local_state.device_to_device.value_or( + global_state.device_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-host transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToHost() { + return GetTransferGuardAction( + thread_local_state.device_to_host.value_or( + global_state.device_to_host.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_get); +} + +} // namespace + +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForHostToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "host-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed host-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToHost()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-host transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-host transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +void BuildGuardSubmodule(nb::module_ &m) { + nb::module_ glib = + 
m.def_submodule("guard_lib", "Jax support library for guards"); + + nb::enum_ tglevel(glib, "TransferGuardLevel"); + tglevel.value("ALLOW", TransferGuardLevel::kAllow); + tglevel.value("LOG", TransferGuardLevel::kLog); + tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); + tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); + tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); + + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", &GuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, + nb::arg().none()); + + glib.def( + "global_state", [&]() { return &global_state; }, + nb::rv_policy::reference); + glib.def( + "thread_local_state", [&]() { return &thread_local_state; }, + nb::rv_policy::reference); +} + +} // namespace jax diff --git a/tests/ci_clangformat/guard_lib.h b/tests/ci_clangformat/guard_lib.h new file mode 100644 index 0000000..8624738 --- /dev/null +++ b/tests/ci_clangformat/guard_lib.h @@ -0,0 +1,115 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GUARD_LIB_H_ +#define JAXLIB_GUARD_LIB_H_ + +#include +#include + +// placeholder for index annotation headers +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// Transfer guard level chosen by the user code. +enum class TransferGuardLevel { + // Explicit transfers: allow + // Implicit transfers: allow + kAllow, + // Explicit transfers: allow + // Implicit transfers: log + kLog, + // Explicit transfers: allow + // Implicit transfers: disallow + kDisallow, + // Explicit transfers: log + // Implicit transfers: log + kLogExplicit, + // Explicit transfers: disallow + // Implicit transfers: disallow + kDisallowExplicit, +}; + +// Garbage collection guard level chose by the user code. +enum class GarbageCollectionGuardLevel { + // Silently allow the object to be garbage collected. + kAllow, + // Log and allow the object to be garbage collected. + kLog, + // Fatal crash on object garbage collection. + kFatal, +}; + +// Flags for guard levels are controlled by: +// - a global flag value, +// e.g., associated to --jax_transfer_guard_device_to_host +// which defaults to TransferGuardLevel::kAllow. 
+// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is used to +// implement context managers that locally override the global state. +// +// Explicit device_put/device_get contexts are tracked by context managers. +struct GuardState { + std::optional host_to_device; + std::optional device_to_device; + std::optional device_to_host; + bool explicit_device_put = false; + bool explicit_device_get = false; + + std::optional garbage_collect_array; +}; + +// Resulting action for a transfer given the transfer guard level and the +// transfer type. +enum class TransferGuardAction { + // Silently allow the transfer. + kAllow, + // Log and allow the transfer. + kLog, + // Disallow the transfer. + kDisallow, +}; + +// Guards a host-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-host transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter); + +// Returns the garbage collection guard level for "jax.Array" objects. +// REQUIRES: Python GIL. +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildGuardSubmodule(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_GUARD_LIB_H_ diff --git a/tests/ci_clangformat/ifrt_proxy.cc b/tests/ci_clangformat/ifrt_proxy.cc new file mode 100644 index 0000000..de9ffb6 --- /dev/null +++ b/tests/ci_clangformat/ifrt_proxy.cc @@ -0,0 +1,162 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
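+
+// The sketch below is an editorial illustration and is not part of the
+// original jaxlib sources. It restates the guard layering documented in
+// guard_lib.h above: an optional thread-local level overrides an optional
+// global level, both falling back to a default, and a Python context manager
+// maps naturally onto an RAII scope over the thread-local slot. The names
+// (Level, g_global_level, tls_level, ScopedLevelOverride) are hypothetical.
+
+#include <optional>
+
+namespace guard_sketch {
+
+enum class Level { kAllow, kLog, kDisallow };
+
+inline std::optional<Level> g_global_level;          // set from Python
+inline thread_local std::optional<Level> tls_level;  // per-thread override
+
+// Mirrors the value_or() chain used by the GetTransferGuardActionFor*()
+// helpers in guard_lib.cc above.
+inline Level ResolveLevel(Level default_level) {
+  return tls_level.value_or(g_global_level.value_or(default_level));
+}
+
+// RAII analogue of a context manager that locally overrides the guard level
+// and restores the previous value when the scope ends.
+class ScopedLevelOverride {
+ public:
+  explicit ScopedLevelOverride(Level level) : saved_(tls_level) {
+    tls_level = level;
+  }
+  ~ScopedLevelOverride() { tls_level = saved_; }
+
+ private:
+  std::optional<Level> saved_;
+};
+
+}  // namespace guard_sketch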
+ +#include "jaxlib/ifrt_proxy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; + std::optional< + std::unordered_map>> + initialization_data; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions &py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. 
Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + + if (py_options.initialization_data.has_value()) { + AttributeMap::Map attribute_map; + for (const auto &[key, py_value] : *py_options.initialization_data) { + if (std::holds_alternative(py_value)) { + nb::bytes value = std::get(py_value); + attribute_map.insert({key, AttributeMap::StringValue(std::string( + value.c_str(), value.size()))}); + } else if (std::holds_alternative(py_value)) { + attribute_map.insert( + {key, AttributeMap::BoolValue(std::get(py_value))}); + } else { + CHECK(std::holds_alternative(py_value)); + attribute_map.insert( + {key, AttributeMap::Int64Value(std::get(py_value))}); + } + } + options.initialization_data = AttributeMap(std::move(attribute_map)); + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return xla::PyClient::Make(std::move(client)); +} + +} // namespace + +void BuildIfrtProxySubmodule(nb::module_ &m) { + nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); + + nb::class_(sub_module, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, + nb::arg().none()) + .def_rw("initialization_data", + &PyClientConnectionOptions::initialization_data, + nb::arg().none()); + + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/tests/ci_clangformat/ifrt_proxy.h b/tests/ci_clangformat/ifrt_proxy.h new file mode 100644 index 0000000..59ab5da --- /dev/null +++ b/tests/ci_clangformat/ifrt_proxy.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_
+#define JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_
+
+#include "nanobind/nanobind.h"
+
+namespace xla {
+namespace ifrt {
+namespace proxy {
+
+void BuildIfrtProxySubmodule(nanobind::module_ &m);
+
+}  // namespace proxy
+}  // namespace ifrt
+}  // namespace xla
+
+#endif  // JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_
diff --git a/tests/ci_clangformat/jax_jit.cc b/tests/ci_clangformat/jax_jit.cc
new file mode 100644
index 0000000..f5a2989
--- /dev/null
+++ b/tests/ci_clangformat/jax_jit.cc
@@ -0,0 +1,495 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements the `jax.jit` dispatch and just-in-time feature.
+//
+// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
+// based on passed arguments dtypes/shapes/identity) the execution to a
+// just-in-time compiled XLA Executable. All of that is done in C++ for
+// performance reasons.
+//
+// This file contains the utilities to:
+// (a) inspect arguments and describe their structure, dtype/shapes, etc.
+// (b) keep a mapping from function signatures to compiled XLA Executables.
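+
+// The sketch below is an editorial illustration and is not part of the
+// original jaxlib sources: it only shows, in miniature, what point (b) above
+// amounts to -- a cache of compiled executables keyed by an argument
+// signature. All names here (Signature, SignatureHasher, CompiledExecutable,
+// ExecutableCache) are hypothetical stand-ins, not the types used by jaxlib.
+
+#include <cstddef>
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace jit_sketch {
+
+// A (simplified) call signature: one abstract value per dynamic argument,
+// e.g. "f32[2,3]". Equality and hashing make it usable as a cache key.
+struct Signature {
+  std::vector<std::string> abstract_args;
+  bool operator==(const Signature &other) const {
+    return abstract_args == other.abstract_args;
+  }
+};
+
+struct SignatureHasher {
+  std::size_t operator()(const Signature &sig) const {
+    std::size_t h = 0;
+    for (const std::string &arg : sig.abstract_args) {
+      h = h * 31 + std::hash<std::string>{}(arg);
+    }
+    return h;
+  }
+};
+
+// Stand-in for a compiled XLA executable.
+struct CompiledExecutable {};
+
+// Maps signatures to executables; compiles at most once per distinct
+// signature and reuses the cached executable afterwards.
+class ExecutableCache {
+ public:
+  template <typename CompileFn>
+  CompiledExecutable &GetOrCompile(const Signature &sig, CompileFn compile) {
+    auto it = cache_.find(sig);
+    if (it == cache_.end()) {
+      it = cache_.emplace(sig, compile()).first;
+    }
+    return it->second;
+  }
+
+ private:
+  std::unordered_map<Signature, CompiledExecutable, SignatureHasher> cache_;
+};
+
+}  // namespace jit_sketch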
+ +#include "jaxlib/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/py_values.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "tsl/profiler/lib/traceme.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl absl::Status. + +namespace { + +// `thread_local_state.extra_jit_context` is set from Python. It's done when +// loading the Python jax modules on the main-thread. For other threads, we +// need to initialize the field the first time we access `thread_local_state`. +nb::object &initialize_local_state = *new nb::object(); + +} // namespace + +JitState &GlobalJitState() { + // Protected by the GIL. + static JitState &global_state = *new JitState(); + return global_state; +} + +JitState &ThreadLocalJitState() { + // TODO(phawkins): Google style guide forbids thread-local values with + // non-trivial destructors. + ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT + DCHECK(PyGILState_Check()); + if (thread_local_state.extra_jit_context == std::nullopt) { + CHECK(initialize_local_state.ptr() != nullptr); + // Avoids reentrant calls to the initialization function. + thread_local_state.extra_jit_context = nb::none(); + initialize_local_state(); + } + return thread_local_state; +} + +bool GetDisableJit() { + auto &global_state = GlobalJitState(); + auto &thread_local_state = ThreadLocalJitState(); + CHECK(global_state.disable_jit.has_value()); + return thread_local_state.disable_jit.value_or(*global_state.disable_jit); +} + +bool GetEnableX64() { + auto &global_state = GlobalJitState(); + auto &thread_local_state = ThreadLocalJitState(); + CHECK(global_state.enable_x64.has_value()); + return thread_local_state.enable_x64.value_or(*global_state.enable_x64); +} + +std::optional GetDefaultDevice() { + auto &global_state = GlobalJitState(); + auto &thread_local_state = ThreadLocalJitState(); + return thread_local_state.default_device.has_value() + ? thread_local_state.default_device + : global_state.default_device; +} + +std::optional GetPostHook() { + auto &global_state = GlobalJitState(); + auto &thread_local_state = ThreadLocalJitState(); + return thread_local_state.post_hook.has_value() ? 
thread_local_state.post_hook + : global_state.post_hook; +} + +static std::string OptionalDebugString( + const std::optional optional) { + if (optional.has_value()) { + return nb::cast(nb::str(optional.value())); + } else { + return "None"; + } +} + +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string *out, const nb::object &o) { + out->append(nb::cast(nb::str(o))); + }; + auto treedef_formatter = [](std::string *out, const xla::PyTreeDef &d) { + out->append(d.ToString()); + }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature &other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object &a, const nb::object &b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error &e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string *out, const nb::object &o) { + out->append(nb::cast(nb::str(o))); + }; + auto signature_formatter = [](std::string *out, + const xla::PyArgSignature &s) { + out->append(s.DebugString()); + }; + auto layout_formatter = [](std::string *out, + const std::shared_ptr &l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; + auto bool_formatter = [](std::string *out, bool o) { + out->append(o ? "true" : "false"); + }; + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "default_device: %s\n" + "jax_enable_x64: %d\n" + "global_extra_jit_context: %s\n" + "thread_local_extra_jit_context: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? 
device->DebugString() : "nullptr", + OptionalDebugString(default_device), jax_enable_x64, + OptionalDebugString(global_extra_jit_context), + OptionalDebugString(thread_local_extra_jit_context), + absl::StrJoin(configs, ", ", py_object_formatter)); +} + +bool CallSignature::operator==(const CallSignature &other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + ShardingEqual) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr &a, + const std::shared_ptr &b) { + return (a && b) ? *a == *b : a == b; + }) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object &a, const nb::object &b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry *pytree_registry, ArgumentSignature &signature, + absl::InlinedVector &flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str &name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef &pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef &pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. 
+ if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also xla::PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. + PyObject *key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair &a, + const std::pair &b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto &kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef &pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_ &m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); + + jitlib.def( + "global_state", [&]() { return &GlobalJitState(); }, + nb::rv_policy::reference); + jitlib.def( + "thread_local_state", [&]() { return &ThreadLocalJitState(); }, + nb::rv_policy::reference); + + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + tls->disable_jit = value; + return result; + }, + nb::arg("value").none(), nb::rv_policy::reference); + + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def("set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const xla::PyArgSignature &sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const xla::PyArgSignature &sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); + + jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", 
&ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature &s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature &a, + const ArgumentSignature &b) { return a == b; }) + .def("__ne__", [](const ArgumentSignature &a, + const ArgumentSignature &b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry *pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str &name : static_argnames) { + PyObject *s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/tests/ci_clangformat/jax_jit.h b/tests/ci_clangformat/jax_jit.h new file mode 100644 index 0000000..0c58fa8 --- /dev/null +++ b/tests/ci_clangformat/jax_jit.h @@ -0,0 +1,266 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_JAX_JIT_H_ +#define JAXLIB_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +// Flags, such as JIT disable and the x64 mode, are controlled by: +// - a global flag value, e.g., associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is +// used to implement context managers that locally override the global state. +struct JitState { + ~JitState() { + if (extra_jit_context) { + // We likely do not hold the GIL if this JitState is thread-local, so we + // hand the Python object to the global reference manager to destroy. + nanobind::object o = std::move(*extra_jit_context); + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); + extra_jit_context = std::nullopt; + } + } + + std::optional disable_jit; + std::optional enable_x64; + + // Used to manually set the default device jax should use. May be unset even + // in global state, indicating there is no manual override. + // TODO(skyewm): make this a C++ type when all JAX backends support a single + // C++ device interface + std::optional default_device; + + // Extra context that should be included in the JIT cache key. Must be + // hashable and have an equality defined. + std::optional extra_jit_context; + + // A callback that, if present, is called when a JITted function is executed + // from cache. May be unset even in global state. + std::optional post_hook; +}; + +JitState &GlobalJitState(); + +// Requires the GIL. +JitState &ThreadLocalJitState(); + +// Getters for JitState fields that first look in thread-local state, then +// fallback to global state. +bool GetDisableJit(); +bool GetEnableX64(); + +// TODO(skyewm): return a C++ type when all JAX backends support a single C++ +// device interface +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. 
+ std::vector static_arg_names; + + bool operator==(const ArgumentSignature &other) const; + bool operator!=(const ArgumentSignature &other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature &s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto &name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto &static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error &e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto &name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry *pytree_registry, ArgumentSignature &signature, + absl::InlinedVector &flat_dynamic_args); + +// The signature of Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (a) equality of the arguments and keyword arguments ArgSignature +// (a) equality (delegated to Python) of the static arguments. +struct CallSignature { + // Not part of the signature, but we need it for error messages. + absl::string_view function_name; + + ArgumentSignature arg_signature; + + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by keyword name). + absl::InlinedVector dynamic_arg_signatures; + + // The sharding of the jax.Array arguments. + std::vector dynamic_arg_shardings; + + // The layout of the jax.Array arguments. 
+ std::vector> dynamic_arg_layouts; + + absl::InlinedVector committed_args; + + // For JIT, we need this in the key because computation follows the data, so + // we may have multiple executables depending on the devices the data is on. + // This is not the case for PMAP, and is set to `nullptr`. + xla::PjRtDevice *device = nullptr; + bool jax_enable_x64; + + // For JIT on PJIT, we need to fallback to python whenever default_device + // changes. + std::optional default_device; + + // Opaque additional context that should be included as part of the cache key. + std::optional global_extra_jit_context; + std::optional thread_local_extra_jit_context; + + std::vector configs; + + bool operator==(const CallSignature &other) const; + bool operator!=(const CallSignature &other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const CallSignature &s) { + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); + + DCHECK(s.dynamic_arg_shardings.empty() || + s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); + + DCHECK(s.dynamic_arg_layouts.empty() || + s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size()); + + // TODO(chky): For now, we are only hashing the pointer of shardings to avoid + // slow python hashing function. Consider implementing hashing function and + // equality checks in C++ in jax::Sharding and use those here. + for (const auto &sharding : s.dynamic_arg_shardings) { + h = H::combine(std::move(h), ShardingHash(sharding)); + } + + for (const auto &layout : s.dynamic_arg_layouts) { + if (layout != nullptr) { + h = H::combine(std::move(h), *layout); + } + } + + h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); + + // We do not hash the extra_jit_context fields since calling Python hash + // functions is expensive (~300ns) and we don't expect a large number of + // different contexts. + return h; +} + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildJaxjitSubmodule(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_JAX_JIT_H_ diff --git a/tests/ci_clangformat/kernel_helpers.h b/tests/ci_clangformat/kernel_helpers.h new file mode 100644 index 0000000..322e4ea --- /dev/null +++ b/tests/ci_clangformat/kernel_helpers.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_KERNEL_HELPERS_H_ +#define JAXLIB_KERNEL_HELPERS_H_ + +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace jax { + +// See kernel_nanobind_helpers.h for info on descriptor objects. We separate out +// the functionality that doesn't require nanobind for building CUDA libraries, +// since older versions nvcc don't seem to be able to compile nanobind. + +// Packs a descriptor object into a byte string. 
+template +std::string PackDescriptorAsString(const T &descriptor) { + return std::string(absl::bit_cast(&descriptor), sizeof(T)); +} + +// Unpacks a descriptor object from a byte string. +template +absl::StatusOr UnpackDescriptor(const char *opaque, + std::size_t opaque_len) { + if (opaque_len != sizeof(T)) { + return absl::InternalError("Invalid size for operation descriptor."); + } + return absl::bit_cast(opaque); +} + +} // namespace jax + +#endif // JAXLIB_KERNEL_HELPERS_H_ diff --git a/tests/ci_clangformat/kernel_nanobind_helpers.h b/tests/ci_clangformat/kernel_nanobind_helpers.h new file mode 100644 index 0000000..88c9124 --- /dev/null +++ b/tests/ci_clangformat/kernel_nanobind_helpers.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_KERNEL_NANOBIND_HELPERS_H_ +#define JAXLIB_KERNEL_NANOBIND_HELPERS_H_ + +#include +#include + +#include "absl/base/casts.h" +#include "jaxlib/kernel_helpers.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +namespace jax { + +// Caution: to use this type you must call tsl::ImportNumpy() in your module +// initialization function. Otherwise PyArray_DescrCheck will be nullptr. +class dtype : public nanobind::object { + public: + NB_OBJECT_DEFAULT(dtype, object, "dtype", PyArray_DescrCheck); // NOLINT + + int itemsize() const { return nanobind::cast(attr("itemsize")); } + + /// Single-character code for dtype's kind. + /// For example, floating point types are 'f' and integral types are 'i'. + char kind() const { return nanobind::cast(attr("kind")); } +}; + +// Descriptor objects are opaque host-side objects used to pass data from JAX +// to the custom kernel launched by XLA. Currently simply treat host-side +// structures as byte-strings; this is not portable across architectures. If +// portability is needed, we could switch to using a representation such as +// protocol buffers or flatbuffers. + +// Packs a descriptor object into a nanobind::bytes structure. +// UnpackDescriptor() is available in kernel_helpers.h. 
+template +nanobind::bytes PackDescriptor(const T &descriptor) { + std::string s = PackDescriptorAsString(descriptor); + return nanobind::bytes(s.data(), s.size()); +} + +template +nanobind::capsule EncapsulateFunction(T *fn) { + return nanobind::capsule(absl::bit_cast(fn), + "xla._CUSTOM_CALL_TARGET"); +} + +template +nanobind::capsule EncapsulateFfiHandler(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be an XLA FFI handler"); + return nanobind::capsule(absl::bit_cast(fn)); +} + +} // namespace jax + +#endif // JAXLIB_KERNEL_NANOBIND_HELPERS_H_ diff --git a/tests/ci_clangformat/mlir.cc b/tests/ci_clangformat/mlir.cc new file mode 100644 index 0000000..1d73319 --- /dev/null +++ b/tests/ci_clangformat/mlir.cc @@ -0,0 +1,236 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mlir.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/service/hlo.pb.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "llvm/Support/raw_ostream.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager &pm) { + auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; + auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +absl::StatusOr HloToStableHlo(const nb::bytes &hlo_module_proto) { + 
mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + HloModuleProto proto; + proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size()); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &proto)); + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. +absl::StatusOr PyXlaComputationToMlirModule( + const XlaComputation &computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + XlaComputation computation; + // SDY dialect may be part of the module which XLA doesn't know about. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*use_shardy=*/false)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. 
+ TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + absl::string_view mlir_module, absl::string_view target) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes &bytecode_str) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_ &m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo), + nb::arg("computation")); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation")); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes &bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes &bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + absl::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes &bytecode, absl::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + absl::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("serialize_portable_artifact", + xla::ValueOrThrowWrapper(PySerializePortableArtifact), + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(RefinePolymorphicShapes( + absl::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. 
+ The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace xla diff --git a/tests/ci_clangformat/mlir.h b/tests/ci_clangformat/mlir.h new file mode 100644 index 0000000..cd4803b --- /dev/null +++ b/tests/ci_clangformat/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MLIR_H_ +#define JAXLIB_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildMlirSubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_MLIR_H_ diff --git a/tests/ci_clangformat/nb_class_ptr.h b/tests/ci_clangformat/nb_class_ptr.h new file mode 100644 index 0000000..368e31f --- /dev/null +++ b/tests/ci_clangformat/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_NB_CLASS_PTR_H_ +#define JAXLIB_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return h.type().is(type); + }; + + T *operator->() const { return nanobind::inst_ptr(ptr()); } + T &operator*() const { return *nanobind::inst_ptr(ptr()); } + T *get() const { return ptr() ? 
nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args &&...args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T *ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // JAXLIB_NB_CLASS_PTR_H_ diff --git a/tests/ci_clangformat/pjit.cc b/tests/ci_clangformat/pjit.cc new file mode 100644 index 0000000..ff5e69c --- /dev/null +++ b/tests/ci_clangformat/pjit.cc @@ -0,0 +1,1401 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "jaxlib/config.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "tsl/profiler/lib/traceme.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(xla::PyTreeRegistry *registry) + : 
out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_weak_types; + std::vector out_shardings; + std::vector out_committed; + xla::PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. +// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice) +// keeping entries alive as long as the underlying function f is alive. +// Assume the cache is protected by the GIL. +class PjitFunctionCache { + public: + static constexpr int kDefaultCapacity = 4096; + explicit PjitFunctionCache(int capacity); + + // Cache entries are shared_ptr<>s because it's possible the cache entry + // might be evicted before we finish tracing/compiling. + typedef xla::LRUCache> Cache; + + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). + static std::shared_ptr Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key); + std::shared_ptr DefaultCache(); + + // These methods require the GIL or the object's lock in no-GIL mode. + int Size() const { return lru_list_.Size(); } + int Capacity() const { return lru_list_.Capacity(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } + + private: + struct Key { + nb::handle function; // Does not hold a reference. + + // Other fields that are part of the arguments to `jit`, but are not + // otherwise part of CallSignature. + nb::object global_cache_key; + + size_t cached_hash; + + bool operator==(const Key &other) const { + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error &e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; + } + + struct Hash { + size_t operator()(const Key &key) const { return key.cached_hash; } + }; + }; + + template + friend H AbslHashValue(H h, const Key &key) { + h = H::combine(std::move(h), key.function.ptr()); + Py_hash_t hash; + try { + hash = nb::hash(key.global_cache_key); + } catch (const nanobind::python_error &e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. 
The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + return h; + } + + struct Value { + explicit Value(std::shared_ptr cache) : cache(std::move(cache)) {} + std::shared_ptr cache; + + // A weak reference to the key function. We use the weak reference to + // register a callback that is triggered when the key function is destroyed. + // We use a weak pointer because we want to allow caching across multiple + // calls to `pjit(f)` if `f` remains alive, but we do not want the cache + // to keep `f` alive if all other references are dropped. + std::optional weakref; + }; + + // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the + // self object lock in freethreading mode. + Cache::LRUList lru_list_; + // We use std::unordered_map because ABSL containers are not exception safe: + std::unordered_map, Key::Hash> functions_; + // mu_ prevents concurrent insertions into functions_ if the gil or critical + // section lock is released during insertion. + absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.Lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. + std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); + }); + PyObject *weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr &entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. 
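+    // (PyWeakref_NewRef fails for objects that do not support weak
+    // references; after clearing the error we also erase the placeholder
+    // entry inserted above so the map stays consistent.)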
+ self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction &) = delete; + PjitFunction &operator=(const PjitFunction &) = delete; + PjitFunction(PjitFunction &&) = default; + PjitFunction &operator=(PjitFunction &&) = default; + + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction *func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. + static PjitFunction *AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject *const *args, + size_t nargs, PyObject *kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string &function_name() const { return function_name_; } + const std::optional &fun() const { return fun_; } + const nb::callable &cache_miss() const { return cache_miss_; } + const xla::nb_class_ptr &pytree_registry() const { + return pytree_registry_; + } + const nb::callable &shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector &static_argnums() const { return static_argnums_; } + const std::vector &static_argnames() const { + return static_argnames_; + } + const nb::object &global_cache_key() const { return global_cache_key_; } + const xla::nb_class_ptr &cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto *inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature &call_signature); + + void PopulateCacheEntry(PjitCacheEntry &cache_entry, + const nb::tuple &out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + xla::nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + xla::nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. 
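+  // Callers therefore take nb::ft_object_guard(cache_) before touching
+  // executables_, as cache_capacity(), ClearCache() and executables() above
+  // do.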
+ std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str &name : static_argnames) { + PyObject *s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. + if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback(nb::handle arg, nb::handle sharding, + nb::handle layout, const nb::callable &fallback, + std::vector &num_args_arrays, + std::vector &keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. 
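+// The fast path forwards committed PyArray inputs that already match the
+// executable's devices; single-device arrays living elsewhere are batch-copied
+// to the executable's first addressable device, and everything else (e.g.
+// PmapSharding inputs, mismatched device counts or layouts, non-PyArray values
+// when the executable spans several devices) is routed through
+// `shard_arg_fallback`, i.e. back into Python.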
+absl::StatusOr> PrepareIfrtInputs( + const xla::PyLoadedExecutable &executable, + absl::Span flat_dynamic_args, + absl::Span flat_dynamic_arg_signatures, + bool enable_x64, const std::vector &kept_args, + const std::vector &in_shardings, + const std::vector &in_device_local_layouts, + const nb::callable &shard_arg_fallback, + std::vector &keep_alive_objects) { + const auto &addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto &num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + xla::DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device *data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1) { + data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object &arg = flat_dynamic_args[i]; + const nb::object &in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto device_put_result, + DevicePutWithDevice(arg, + executable.ifrt_loaded_executable()->client(), + data_device, xla::ifrt::MemoryKind(), options)); + num_args_arrays.push_back(std::move(device_put_result.ifrt_array)); + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + xla::PyArray py_array = nb::borrow(arg); + const auto &sharding = py_array.sharding(); + int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). 
+ DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array *ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `xla::PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto &ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto ©_group = + copy_groups[std::make_pair(ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind())]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty()) { + xla::ifrt::Client *const ifrt_client = + executable.ifrt_loaded_executable()->client(); + xla::ifrt::DeviceListRef ifrt_devices = + ifrt_client->MakeDeviceList({addressable_devices[0]}); + for (auto &[key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + /*memory_kind=*/std::nullopt, + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject *const *args, + size_t nargs, PyObject *kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) + // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. 
+ auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto &arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + continue; + } + + xla::PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stablizes. + if (!py_array.committed() && + jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature &unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. 
+ out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception &e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.dynamic_arg_signatures, call_signature.jax_enable_x64, + cache_entry->kept_var_bitvec, cache_entry->in_shardings, + cache_entry->in_device_local_layouts, shard_arg_fallback_, + keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + auto traceback = xla::Traceback::Get(); + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. + xla::PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + traceback, std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. 
+ std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature &signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState &global_state = jax::GlobalJitState(); + JitState &tls = jax::ThreadLocalJitState(); + bool jax_enable_x64 = GetEnableX64(); + + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + + auto &dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto &dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto &dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + xla::PyArgSignatureOfValue(arg, jax_enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry &cache_entry, + const nb::tuple &out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for 
(nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } +} + +// Helper function used by the tp_clear GC method. +void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject *dict; // Dictionary for __dict__ + PyObject *weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject *next; + PjitFunctionObject *prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject *o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject *o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject *fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto &[cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. 
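+  // The store is an intrusive doubly-linked list (next/prev live in
+  // PjitFunctionObject itself), so Insert/Remove are O(1) and allocate
+  // nothing during PjitFunction_tp_dealloc.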
+ nb::ft_mutex mu_; + PjitFunctionObject *compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject *PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction *PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction *AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject *PjitFunction_tp_vectorcall(PyObject *callable, PyObject *const *args, + size_t nargs, PyObject *kwnames) { + PjitFunctionObject *o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error &e) { + e.restore(); + return nullptr; + } catch (nb::cast_error &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject *PjitFunction_tp_new(PyTypeObject *subtype, PyObject *args, + PyObject *kwds) { + PjitFunctionObject *self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject *self) { + PyObject_GC_UnTrack(self); + PyTypeObject *tp = Py_TYPE(self); + PjitFunctionObject *o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject *self, visitproc visit, void *arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject *o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject *self) { + PjitFunctionObject *o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 
+ o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so JIT-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject *PjitFunction_tp_descr_get(PyObject *self, PyObject *obj, + PyObject *type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef PjitFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyObject *PjitFunction_tp_repr(PyObject *self) { + try { + const std::string &repr = absl::StrFormat( + "", nb::cast(nb::repr( + nb::getattr(self, "__wrapped__")))); + return PyUnicode_FromString(repr.c_str()); + } catch (...) { + // Ignore all errors when accessing a repr. + return PyUnicode_FromString(""); + } +} + +} // extern "C" + +void InitializePjitFunction( + PjitFunctionObject *fn_obj, std::string function_name, + std::optional fun, nb::callable cache_miss, + std::vector static_argnums, std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; + if (nb::isinstance(global_cache_key)) { + global_cache_key = nb::tuple(global_cache_key); + } + new (&fn_obj->fun) PjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + // Handled separately because it is not exception safe to call this + // in the constructor because it leaves the object improperly constructed. + fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); +} + +nb::object MakePjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + std::optional> cache) { + nb::object obj = nb::steal(PjitFunction_tp_new( + reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); + PjitFunctionObject *fn_obj = + reinterpret_cast(obj.ptr()); + if (!cache) { + cache = xla::make_nb_class( + PjitFunctionCache::kDefaultCapacity); + } + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); + return obj; +} + +// Version numbers for the pickled representations of +// PjitFunction. Increment these if changing them. 
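+// __setstate__ below rejects pickles whose version does not match, so bumping
+// this value invalidates old pickles instead of silently misreading them.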
+const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_ &m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache &cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache *cache, const nb::dict &pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the _jax module so it can be pickled. 
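+  // __getstate__ captures only what is needed to rebuild the wrapper
+  // (cache_miss, static argnums/argnames, pytree registry, ...); compiled
+  // executables are not pickled and the shared cache pickles as empty.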
+ m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object &self) { + PjitFunction *fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object &self, const nb::dict &pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + xla::nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->PythonSignature(); + }); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + xla::nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + 
nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); +} + +} // namespace jax diff --git a/tests/ci_clangformat/pjit.h b/tests/ci_clangformat/pjit.h new file mode 100644 index 0000000..c782526 --- /dev/null +++ b/tests/ci_clangformat/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PJIT_H_ +#define JAXLIB_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_ &m); +} + +#endif // JAXLIB_PJIT_H_ diff --git a/tests/ci_clangformat/pmap_lib.cc b/tests/ci_clangformat/pmap_lib.cc new file mode 100644 index 0000000..4d2b727 --- /dev/null +++ b/tests/ci_clangformat/pmap_lib.cc @@ -0,0 +1,1141 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/pmap_lib.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "jaxlib/config.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharded_device_array.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "tsl/profiler/lib/traceme.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +namespace { + +// Specifies how to shard the inputs. Even though everything could be computed +// from `sharding_specs` and the argument shape, we cache derived computations +// for performance. +struct InputSpec { + InputSpec(nb::object indices, nb::object array_sharding) + : indices(std::move(indices)), + array_sharding(std::move(array_sharding)) {} + nb::object indices; + nb::object array_sharding; +}; + +// An object containing the arguments to create Array from the +// output buffers. +struct ResultSpec { + public: + explicit ResultSpec(nb::object aval) + : out_aval(std::move(aval)), + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; + bool weak_type; +}; + +// The result of `ShardArg`. +struct ShardArgResult { + // Points to the on-device array. + // ifrt_array->sharding().num_shards() == `num_devices`. + xla::ifrt::ArrayRef ifrt_array; + // The Python argument will be always be copied to `owning_sda`. + nb::object owning_sda; +}; + +// Shards a single argument over devices. +// +// We currently only support fully in C++, C++ Array. For all +// other usages, we call a Python function returning C++ Array +// that will be casted back to the C++ objects. +// +// This function is not usable for JAX extensions that do not comply with the +// PjRt interfaces. +// +// Arguments: +// `arg`: The object to shard across `devices`. If a `Array`, +// a fast-path will be executed if it's already correctly sharded. 
+// +// Returns a failure absl::Status when an unrecoverable error occurred, so we +// don't need to fallback to Python. +// +// Both `devices` and `sharding_spec` has the same length. +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec &input_spec, nb::handle py_devices, + const nb::callable &python_fallback) { + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto *pmap_sharding = nb::cast(py_array.sharding()); + auto *cached_pmap_sharding = + nb::cast(input_spec.array_sharding); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices()->devices() != devices) { + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto *ifrt_client = result.ifrt_array->client(); + TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&result.ifrt_array, 1), + ifrt_client->MakeDeviceList(ifrt_devices), + xla::ifrt::MemoryKind(), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_arrays.front()); + } + return result; + } + } + } + + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && + xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { + tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); + auto n_devices = py_devices_list.size(); + if (indices.size() != n_devices) { + return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", + indices.size(), n_devices); + } + + ShardArgResult result; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector owning_args; + std::vector args; + owning_args.reserve(n_devices); + args.reserve(n_devices); + xla::DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Client *ifrt_client = nullptr; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + if (i == 0) { + ifrt_client = to_device->client()->ifrt_client(); + } + owning_args.push_back(arg[indices[i]]); + args.push_back(owning_args.back()); + } + CHECK(ifrt_client != nullptr); + TF_ASSIGN_OR_RETURN( + xla::DevicePutResult device_put_result, + xla::DevicePutWithSharding( + args, ifrt_client, ndarray.dtype(), + nb::cast>(ndarray.attr("shape")), + input_spec.array_sharding, options)); + result.ifrt_array = std::move(device_put_result.ifrt_array); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit 
PmapCacheEntry(xla::PyTreeRegistry *registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + xla::PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction &) = delete; + PmapFunction &operator=(const PmapFunction &other) = delete; + PmapFunction(PmapFunction &&) = default; + PmapFunction &operator=(PmapFunction &&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject *const *args, + size_t nargs, PyObject *kwnames); + + nb::object PythonSignature() { + static const auto *inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable &fun() const { return fun_; } + const nb::callable &cache_miss() const { return cache_miss_; } + const std::string &function_name() const { return function_name_; } + const xla::nb_class_ptr &pytree_registry() const { + return pytree_registry_; + } + const nb::callable &python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector &static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction *func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PmapFunction. 
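The check declared just below works by comparing the object's type pointer against the heap type object created later in this file. A minimal sketch of the two usual variants of such a check, with hypothetical `IsExactly`/`IsInstance` helpers and a hypothetical `my_type`:

#include <Python.h>

// True iff `obj` is exactly an instance of `my_type` (subclasses excluded).
inline bool IsExactly(PyObject *obj, PyTypeObject *my_type) {
  return Py_TYPE(obj) == my_type;
}

// True for instances of `my_type` and of any of its subclasses.
inline bool IsInstance(PyObject *obj, PyTypeObject *my_type) {
  return PyObject_TypeCheck(obj, my_type);
}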
+ static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction *AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature &signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState &global_state = jax::GlobalJitState(); + JitState &tls = jax::ThreadLocalJitState(); + const bool jax_enable_x64 = GetEnableX64(); + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto &pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry &cache_entry, + const nb::tuple &out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + xla::nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. 
+ nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry &cache_entry, + const nb::tuple &out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. + std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error &e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector> &devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto &device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject *const *args, + size_t nargs, PyObject *kwnames) { + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. 
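The lambda defined next invokes the Python cache-miss callable through CPython's vectorcall protocol: positional arguments arrive as a contiguous array, keyword names as a tuple, and a null result means a Python exception is pending. A minimal sketch of that calling convention, assuming a hypothetical callable `fn` and two positional arguments (`CallTwoArgs` is an illustrative helper, not part of jaxlib):

#include "nanobind/nanobind.h"

namespace nb = nanobind;

// Calls `fn(a, b)` via the vectorcall protocol. The result is an empty
// object, with the Python error indicator set, if `fn` raised.
nb::object CallTwoArgs(nb::handle fn, nb::handle a, nb::handle b) {
  PyObject *argv[] = {a.ptr(), b.ptr()};
  return nb::steal(
      PyObject_Vectorcall(fn.ptr(), argv, /*nargsf=*/2, /*kwnames=*/nullptr));
}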
+ auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr &entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry &cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception &e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + // 1. Parse arguments. + std::vector &input_devices = cache_entry.devices; + std::vector &input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. 
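The cache handling above is a compile-once-per-signature pattern: the thread that inserts the entry performs the compilation and notifies; any other thread that finds an existing entry waits on the notification (the real code additionally notifies on exceptions so waiters never hang, and releases the GIL while waiting). Reduced to a minimal self-contained sketch with hypothetical `Entry` and `GetOrCompile` names:

#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"

struct Entry {
  std::string compiled;                 // stand-in for the compiled executable
  absl::Notification compilation_done;  // signalled exactly once
};

// Returns the entry for `key`, compiling it at most once even when called
// concurrently from several threads.
std::shared_ptr<Entry> GetOrCompile(
    absl::Mutex &mu,
    absl::flat_hash_map<std::string, std::shared_ptr<Entry>> &cache,
    const std::string &key) {
  std::shared_ptr<Entry> entry;
  bool inserted = false;
  {
    absl::MutexLock lock(&mu);
    std::shared_ptr<Entry> &slot = cache[key];
    if (!slot) {
      slot = std::make_shared<Entry>();
      inserted = true;
    }
    entry = slot;
  }
  if (inserted) {
    entry->compiled = "compiled(" + key + ")";  // the expensive step
    entry->compilation_done.Notify();
  } else {
    entry->compilation_done.WaitForNotification();
  }
  return entry;
}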
+ std::vector num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + auto traceback = xla::Traceback::Get(); + // TODO(jblespiau): Change the `client` function to return a reference. + xla::nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto &output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec &result_spec = output_specs[i]; + xla::PyArray py_array( + result_spec.out_aval, result_spec.weak_type, cache_entry.out_dtypes[i], + cache_entry.out_shapes[i], cache_entry.out_array_shardings[i], client, + traceback, std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject *dict; // Dictionary for __dict__ + PyObject *weakrefs; // Weak references; for use by the Python interpreter. 
+#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject *JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction *PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject *JaxPmapFunction_tp_vectorcall(PyObject *callable, + PyObject *const *args, size_t nargs, + PyObject *kwnames) { + JaxPmapFunctionObject *o = + reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error &e) { + e.restore(); + return nullptr; + } catch (nb::cast_error &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject *JaxPmapFunction_tp_new(PyTypeObject *subtype, PyObject *args, + PyObject *kwds) { + JaxPmapFunctionObject *self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject *self) { + PyObject_GC_UnTrack(self); + PyTypeObject *tp = Py_TYPE(self); + JaxPmapFunctionObject *o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject *self, visitproc visit, void *arg) { + JaxPmapFunctionObject *o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject *self) { + JaxPmapFunctionObject *o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. 
See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject *JaxPmapFunction_tp_descr_get(PyObject *self, PyObject *obj, + PyObject *type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef JaxPmapFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyMemberDef JaxPmapFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, weakrefs)), + READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot JaxPmapFunction_slots[] = { + {Py_tp_new, reinterpret_cast(JaxPmapFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(JaxPmapFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(JaxPmapFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(JaxPmapFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(JaxPmapFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(JaxPmapFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_members, reinterpret_cast(JaxPmapFunction_members)}, + {0, nullptr}, +}; + +} // extern "C" + +nb::object MakePmapFunction( + nb::callable fun, nb::callable cache_miss, std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) { + nb::object obj = nb::steal(JaxPmapFunction_tp_new( + reinterpret_cast(JaxPmapFunction_Type), nullptr, + nullptr)); + JaxPmapFunctionObject *buf = + reinterpret_cast(obj.ptr()); + new (&buf->fun) PmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(python_shard_arg_fallback), std::move(pytree_registry)); + return obj; +} + +// Version numbers for the pickled representations. +// Increment these if changing them. 
+const int kPmapFunctionPickleVersion = 1; + +} // namespace + +void BuildPmapSubmodule(nb::module_ &m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + + nb::class_ no_sharding(pmap_lib, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding &self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding &self, nb::tuple t) { new (&self) NoSharding(); }) + .def("__repr__", [](const NoSharding &self) { return "NoSharding()"; }) + .def("__eq__", + [](const NoSharding &self, nb::object obj) { + return nb::isinstance(obj); + }) + .def("__hash__", [](const NoSharding &self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + nb::class_ chunked(pmap_lib, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked &self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked &self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) + .def("__repr__", + [](const Chunked &self) { + return absl::StrCat("Chunked(", absl::StrJoin(self.chunks, ","), + ")"); + }) + .def("__eq__", [](const Chunked &self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ unstacked(pmap_lib, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked &self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked &self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) + .def("__repr__", + [](const Unstacked &x) { + return absl::StrCat("Unstacked(", x.size, ")"); + }) + .def("__eq__", [](const Unstacked &self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis &self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis &self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) + .def("__repr__", + [](const ShardedAxis &x) { + return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); + }) + .def("__eq__", [](const ShardedAxis &self, const ShardedAxis &other) { + return self == other; + }); + + nb::class_ replicated(pmap_lib, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated &self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated &self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) + .def("__repr__", + [](const Replicated &x) { + return absl::StrCat("Replicated(replicas=", x.replicas, ")"); + }) + .def("__eq__", [](const Replicated &self, const Replicated &other) { + return self == other; + }); + + nb::class_ sharding_spec(pmap_lib, "ShardingSpec"); + sharding_spec + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec &self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec &self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( + "sharding", + 
[](const ShardingSpec &self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec &self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", [](const ShardingSpec &self, + const ShardingSpec &other) { return self == other; }) + .def("__hash__", [](const ShardingSpec &self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + + // Add PmapFunction to the _jax module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction *fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }); + // Required by `post_hook`. + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction *fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object &self) { + PmapFunction *fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object &self, const nb::dict &pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. 
" + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + xla::nb_class_ptr pytree_registry = + nb::cast>( + pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. + cfun.attr("_cache_size") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction *fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction *fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction *fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + xla::nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); +} + +} // namespace jax diff --git a/tests/ci_clangformat/pmap_lib.h b/tests/ci_clangformat/pmap_lib.h new file mode 100644 index 0000000..a02f7fd --- /dev/null +++ b/tests/ci_clangformat/pmap_lib.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PMAP_LIB_H_ +#define JAXLIB_PMAP_LIB_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). 
+ +namespace jax { + +void BuildPmapSubmodule(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_PMAP_LIB_H_ diff --git a/tests/ci_clangformat/py_array.cc b/tests/ci_clangformat/py_array.cc new file mode 100644 index 0000000..2dfa553 --- /dev/null +++ b/tests/ci_clangformat/py_array.cc @@ -0,0 +1,2137 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_array.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include 
"xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "llvm/Support/Casting.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +PjRtBuffer *GetPjrtBuffer(ifrt::Array *ifrt_array) { + auto *arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +absl::StatusOr XlaDynamicShape(ifrt::Array *ifrt_array, + std::optional &scratch) { + auto *pjrt_buffer = GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); + scratch = std::move(shape); + } + return &scratch.value(); +} + +ifrt::ArrayRef CreateIfRtArrayFromSingleDeviceShardedPyArrays( + nb_dtype dtype, absl::Span shape, + absl::Span py_arrays, const nb::object &sharding) { + const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector ifrt_arrays; + ifrt_arrays.reserve(py_arrays.size()); + absl::InlinedVector devices; + devices.reserve(py_arrays.size()); + absl::flat_hash_set device_set; + device_set.reserve(py_arrays.size()); + std::vector shapes; + shapes.reserve(py_arrays.size()); + + auto sharding_device_list = xla::GetIfrtDeviceList(sharding); + if (!sharding_device_list.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(sharding_device_list.status().ToString().c_str()); + } + ifrt::Device *device = sharding_device_list.value()->devices().front(); + + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_dst_memory_kind = + ifrt::CanonicalizeMemoryKind(dst_memory_kind, device); + for (const auto &py_array : py_arrays) { + if (py_array.num_shards() != 1) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had %d shard(s).", + py_array.num_shards()) + .c_str()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + ifrt::Device *const device = + ifrt_arrays.back()->sharding().devices()->devices().front(); + devices.push_back(device); + device_set.insert(device); + shapes.push_back(ifrt_arrays.back()->shape()); + if (canonical_dst_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_arrays.back()->sharding().memory_kind(), device)) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch with PjRtBuffers. 
Got sharding with " + "memory kind '%v' and a buffer with memory_kind '%v'", + dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) + .c_str()); + } + } + ifrt::DeviceListRef device_list = device->client()->MakeDeviceList(devices); + if (device_set.size() != device_list->size()) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays, the input arrays " + "must be from distinct devices, but got %v", + *device_list) + .c_str()); + } + + auto ifrt_dtype = DtypeToIfRtDType(dtype); + if (!ifrt_dtype.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); + } + + absl::StatusOr ifrt_sharding = + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape)); + if (!ifrt_sharding.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyBaseArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject *weakrefs; +#endif // PY_VERSION_HEX < 0x030C0000 +}; + +extern "C" void PyBaseArray_tp_dealloc(PyBaseArrayObject *self) { + PyObject_GC_UnTrack(self); + PyObject_ClearWeakRefs((PyObject *)self); + PyTypeObject *tp = Py_TYPE(self); + tp->tp_free((PyObject *)self); + Py_DECREF(tp); +} + +extern "C" int PyBaseArray_tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + return 0; +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject *weakrefs; + PyObject *dict; +#endif // PY_VERSION_HEX < 0x030C0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage *GetPyArrayStorageFromObject(PyArrayObject *py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject *PyArray_tp_new(PyTypeObject *type, PyObject *, + PyObject *) { + PyObject *self = type->tp_alloc(type, 0); + auto *obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject *self) { + PyObject_GC_UnTrack(self); + PyTypeObject *tp = Py_TYPE(self); + auto *obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. 
+extern "C" int PyArray_tp_traverse(PyObject *self, visitproc visit, void *arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" int PyArray_tp_clear(PyObject *self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto *obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value()->ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage *Construct(PyArrayObject *self, Args &&...args) { + PyArray::Storage *out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey &value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey &other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey &key) { + using CacheT = + LRUCache>>; + static nb::ft_mutex mu; + static auto *lru_list = new CacheT::LRUList(4096); + static auto *cache = new CacheT(lru_list); + + static const nb::object *shaped_array = []() -> nb::object * { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error &e) { + return nullptr; + } + return new nb::object(jax_core.attr("ShapedArray")); + }(); + if (!shaped_array) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey &key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + nb_dtype dtype = + IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = (*shaped_array)( + SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? 
std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. +struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey &other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey &key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage( + nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, bool committed, + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + traceback(std::move(traceback)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard &shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = nb::cast( + sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + Traceback::Get(), std::move(ifrt_array), xla::PjRtFuture<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw XlaRuntimeError( + InvalidArgument("Constructing single device jax.Array from non-single " + "device ifrt array.")); + 
} + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, nb::object sharding, bool weak_type, + bool committed, bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, + bool committed, bool skip_checks) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { + auto py_device_list = jax::GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::PjRtFuture<>()); +} + +PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status) const { + return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), Traceback::Get(), std::move(ifrt_array), + committed_, skip_checks_, std::move(result_status)); +} + +PyArray PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, + ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, + xla::PjRtFuture<> result_status) { + auto *self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + std::move(dtype), 
std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage &PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage &PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) { + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector &PyArray::py_arrays_cached() { + auto &py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto &ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(ifrt_array), weak_type(), + committed(), result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. 
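The build-once-and-reuse behavior described in the comment above fits in a few lines; a minimal sketch with a hypothetical `ShardedValue` type, where the first call materializes the per-shard views and later calls return the cached vector:

#include <string>
#include <vector>

class ShardedValue {
 public:
  // Returns cached per-shard views, creating them on first use only.
  const std::vector<std::string> &shard_views() {
    if (views_.empty()) {
      views_.reserve(num_shards_);
      for (int i = 0; i < num_shards_; ++i) {
        views_.push_back("view-of-shard-" + std::to_string(i));  // stand-in
      }
    }
    return views_;
  }

 private:
  int num_shards_ = 4;  // placeholder shard count
  std::vector<std::string> views_;
};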
+ if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + std::vector ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return InvalidArgument("Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto &ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(jax::PmapSharding::type()) + ? 
xla::GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt::Shape(shape()), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto &cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(fully_replicated_ifrt_shard), + weak_type(), committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array *ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + if (ifrt_array() == nullptr) { + return InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto &result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). 
+ nb::gil_scoped_release release_gil; + BlockUntilReadyWithCancel(result_status); + return result_status.Await(); + } + return result_status.Await(); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { + if (ifrt_array() == nullptr) { + return InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto &py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return InvalidArgument("%s() is supported only for unsharded arrays.", api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array *ifrt_array = arr.ifrt_array(); + std::optional &scratch = arr.GetStorage().dynamic_shape; + auto *pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto *dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + 
ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void *root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict &cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void *data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. 
Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + auto *pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto *ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto &arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(ifrt::ArrayRef()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto *ifrt_client = py_client()->ifrt_client(); + ifrt::ArrayRef out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), traceback(), std::move(out), + committed(), /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard &shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. 
+ nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client *const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector ifrt_arrays; + }; + absl::flat_hash_map batches; + + auto traceback = Traceback::Get(); + for (int i = 0; i < py_arrays.size(); ++i) { + const auto &py_array = py_arrays[i]; + const auto &dst_sharding = dst_shardings[i]; + const auto &array_cs = array_copy_semantics[i]; + + auto *ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef &src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef &dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + if (jax::ShardingEqual(py_array.sharding(), dst_sharding)) { + results[i] = py_arrays[i]; + } else { + absl::Span shape_span = py_array.shape(); + // We can reuse the input array despite the sharding being different. + // This is because this code expects no resharding is necessary, which + // has been verified by the code invoking this method. + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_sharding, py_array.py_client(), traceback, + tsl::FormRef(ifrt_array_ptr), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch &batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + + std::vector> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + for (auto &[key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. 
+ key.dst_devices, key.dst_memory_kind, key.array_copy_semantics)); + for (int i = 0; i < batch.indexes.size(); ++i) { + ifrt_arrays.push_back( + std::make_pair(batch.indexes[i], std::move(copied[i]))); + } + } + } + + for (auto &[i, ifrt_array] : ifrt_arrays) { + const auto &py_array = py_arrays[i]; + absl::Span shape_span = py_array.shape(); + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_shardings[i], py_array.py_client(), traceback, + std::move(ifrt_array), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + return results; +} + +absl::StatusOr PyArray::BatchedDevicePut( + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64) { + if (dst_devices.size() != xs.size()) { + throw nb::value_error( + absl::StrCat("Argument sizes (xs and devices) must match %zu vs %zu", + dst_devices.size(), xs.size()) + .c_str()); + } + for (const PyDevice *device : dst_devices) { + if (device->client().get() == nullptr) { + return InvalidArgument("Cannot copy to unattached devices."); + } + } + auto transfer_guard_formatter = [&aval, &sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); + }; + + GlobalPyRefManager()->CollectGarbage(); + + auto n_devices = dst_devices.size(); + + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + + std::vector ifrt_arrays; + + absl::InlinedVector devices; + devices.reserve(n_devices); + std::vector shapes; + shapes.reserve(n_devices); + + std::vector args; + args.reserve(xs.size()); + for (const nb::object &x : xs) { + if (PyArray::IsPyArray(x)) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + } else { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + } + args.push_back(x); + } + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + TF_ASSIGN_OR_RETURN(nb_class_ptr py_device_list, + jax::GetPyDeviceList(sharding)); + + TF_ASSIGN_OR_RETURN( + DevicePutResult device_put_result, + DevicePutWithSharding(args, py_device_list->py_client()->ifrt_client(), + dtype, shape, sharding, options)); + + return PyArray(aval, weak_type, dtype, std::move(shape), std::move(sharding), + py_device_list->py_client(), Traceback::Get(), + std::move(device_put_result.ifrt_array), committed, + /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array *ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client *const client = ifrt_array_ptr->client(); + + const auto &device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + 
dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + xla::ifrt::ShardingRef dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + xla::ifrt::ArrayRef new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto &mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return xla::PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), x.traceback(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. 
+ // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array *const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); +} + +absl::Status PyArray::ReplaceWithAlias(PyArray o) { + auto &storage = GetStorage(); + auto &o_storage = o.GetStorage(); + if (storage.py_client.get() != o_storage.py_client.get()) { + return absl::InvalidArgumentError( + "Unable to replace a PyArray with a PyArray from a different client."); + } + storage.aval = o_storage.aval; + storage.weak_type = o_storage.weak_type; + storage.dtype = o_storage.dtype; + storage.shape = o_storage.shape; + storage.sharding = o_storage.sharding; + storage.npy_value = o_storage.npy_value; + storage.committed = o_storage.committed; + storage.traceback = o_storage.traceback; + storage.ifrt_array = o_storage.ifrt_array; + storage.fully_replicated_array = o_storage.fully_replicated_array; + storage.py_arrays = o_storage.py_arrays; + storage.host_value.Clear(); + storage.dynamic_shape = o_storage.dynamic_shape; + storage.result_status = o_storage.result_status; + + return absl::OkStatus(); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto &shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage *array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo( + std::shared_ptr buffer, + std::unique_ptr external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the PjRtBuffer. This prevents a + // use-after-free in the event that Delete() is called on a buffer with an + // live buffer protocol view. It does however mean that Delete() sometimes + // won't actually delete immediately. + std::shared_ptr buffer; + std::unique_ptr external_reference_hold; +}; + +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const Layout &layout) { + return LayoutUtil::IsMonotonicWithDim0Major(layout) && layout.tiles().empty(); +} + +int PyArray_bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags) { + absl::Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); + if (py_array.ifrt_array() == nullptr) { + // TODO(phawkins): why is this happening? 
+ return InvalidArgument("Array is null"); + } + if (!llvm::isa(py_array.ifrt_array())) { + return InvalidArgument("Only local arrays are supported, got %s", + py_array.ifrt_array()->DebugString()); + } + auto *array = + static_cast(py_array.ifrt_array()); + absl::Span> buffers = + array->pjrt_buffers(); + + if (buffers.empty()) { + return InvalidArgument("Array has no buffers."); + } + PjRtBuffer &buffer = *buffers.front(); + if (!buffer.IsOnCpu()) { + return InvalidArgument( + "Python buffer protocol is only defined for CPU buffers."); + } + + if (buffers.size() != 1) { + return InvalidArgument( + "Python buffer protocol is only defined for buffers with a single " + "shard."); + } + if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) { + return InvalidArgument( + "Python buffer protocol is only defined for single-device sharded " + "buffers."); + } + + const char *format = + PEP3118FormatDescriptorForPrimitiveType(buffer.element_type()); + // It isn't an option for us to export unknown types as, say, bytes. When + // converting an object to an ndarray, NumPy tries the buffer protocol + // first. We very much want NumPy to fail and fall back to using + // __array__, which allows us to handle custom dtypes correctly. + if (!format) { + return InvalidArgument( + "Buffers of type %s are not supported by the Python buffer protocol.", + PrimitiveType_Name(buffer.element_type())); + } + + std::unique_ptr external_reference_hold; + { + // We call BlockHostUntilReady() below, which may block. + nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. 
+ std::memset(view, 0, sizeof(Py_buffer)); + const void *root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject *, Py_buffer *buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape &shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major_size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape &shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer *buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional &dynamic_shape_holder, ifrt::Array *ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. 
+ if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto *arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto *pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto *shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + ifrt::ArrayRef buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto *hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void *h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. + nb::gil_scoped_release gil; + TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + auto fut = ifrt_array->GetReadyFuture(); + BlockUntilReadyWithCancel(fut); + TF_RETURN_IF_ERROR(fut.Await()); + } + void *data = + hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + return std::make_pair(array, false); + } + } + + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + BlockUntilReadyWithCancel(ready_); + TF_RETURN_IF_ERROR(ready_.Await()); + } else { + TF_RETURN_IF_ERROR(ready_.Await()); + } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } + return std::make_pair(value_, true); +} + +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array *ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject *>(value_.ptr()); + auto iter = nb::steal( + PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto &cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status 
PyHostValue::CopyStringArrayToHostAsync(
+    std::optional<Shape> &dynamic_shape_holder, ifrt::Array *ifrt_array) {
+  auto transfer_guard_formatter = [ifrt_array] {
+    return absl::StrCat(
+        "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
+        "), dtype=", ifrt_array->dtype().DebugString(), ", device=",
+        ifrt_array->sharding().devices()->devices().front()->DebugString());
+  };
+  TF_RETURN_IF_ERROR(
+      jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));
+
+  TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype()));
+  auto shape = ifrt_array->shape();
+
+  // Allocate a vector of cords to hold the contents of the array until
+  // they are ultimately converted to a numpy array as part
+  // of the `AsNumPyArray` call.
+  string_array_contents_ =
+      std::make_shared<std::vector<absl::Cord>>(shape.num_elements());
+  ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(),
+                                        /*byte_strides=*/std::nullopt,
+                                        ifrt::ArrayCopySemantics::kAlwaysCopy);
+
+  ready_.OnReady(
+      [string_array_contents = string_array_contents_](absl::Status) {
+      });  // Keeps the cords alive until the copy is done.
+
+  return absl::OkStatus();
+}
+
+absl::Status PyHostValue::CopyToHostAsync(
+    std::optional<Shape> &dynamic_shape_holder, ifrt::Array *ifrt_array) {
+  if (ready_.IsValid()) {
+    // The array value has been populated, so CopyToHostAsync has been called.
+    return absl::OkStatus();
+  }
+
+  // Copying in Arrays of type kString requires some special handling.
+  if (ifrt_array->dtype().kind() == ifrt::DType::kString) {
+    return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array);
+  }
+
+  auto *arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
+  if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() &&
+      IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) {
+    return absl::OkStatus();
+  }
+  auto transfer_guard_formatter = [ifrt_array] {
+    return absl::StrCat(
+        "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
+        "), dtype=", ifrt_array->dtype().DebugString(), ", device=",
+        ifrt_array->sharding().devices()->devices().front()->DebugString());
+  };
+  TF_RETURN_IF_ERROR(
+      jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter));
+
+  // TODO(b/182461453): This is a blocking call. If we further implemented
+  // populating dynamic shape metadata while fetching the literal, we wouldn't
+  // need this static approach.
+  const xla::Shape *dynamic_shape;
+  std::optional<Shape> shape_holder;
+  if (llvm::isa<ifrt::PjRtCompatibleArray>(ifrt_array)) {
+    TF_ASSIGN_OR_RETURN(dynamic_shape,
+                        XlaDynamicShape(ifrt_array, dynamic_shape_holder));
+  } else {
+    // Skip querying the dynamic shape for a non-PjRt Array.
+    TF_ASSIGN_OR_RETURN(xla::PrimitiveType type,
+                        ifrt::ToPrimitiveType(ifrt_array->dtype()));
+    shape_holder = ShapeUtil::MakeShapeWithDescendingLayout(
+        type, ifrt_array->shape().dims());
+    dynamic_shape = &*shape_holder;
+  }
+
+  xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape);
+
+  auto strides = ByteStridesOrDefaultForShapeInt64(host_shape);
+  TF_ASSIGN_OR_RETURN(nb_dtype dtype,
+                      PrimitiveTypeToNbDtype(host_shape.element_type()));
+  value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides);
+  // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses
+  // the same transposition as the device buffer. This is different from
+  // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout
+  // of the host buffer literal. On the other hand, the runtime often knows
+  // better about an efficient layout for the host buffer.
It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +void PyHostValue::Clear() { + ready_ = {}; + value_ = {}; + string_array_contents_ = {}; +} + +namespace { +PyMemberDef PyBaseArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyBaseArrayObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PyBaseArray_slots[] = { + {Py_tp_dealloc, reinterpret_cast(PyBaseArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyBaseArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyBaseArray_tp_traverse)}, + {Py_tp_hash, reinterpret_cast(PyObject_HashNotImplemented)}, + {0, nullptr}, +}; + +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +} // namespace + +absl::Status PyArray::RegisterTypes(nb::module_ &m) { + // We are not using nanobind to avoid having a non-standard metaclass, which + // would make Array incompatible with abc.ABCMeta. + std::string base_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); + PyType_Spec PyBaseArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. 
+ /*.name=*/strdup(base_name.c_str()), +#else + /*.name=*/base_name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyBaseArray_slots}; + auto *base_type = PyType_FromSpec(&PyBaseArray_spec); + if (!base_type) { + throw nb::python_error(); + } + m.attr("Array") = nb::borrow(base_type); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; + + type_ = PyType_FromSpecWithBases(&PyArray_spec, base_type); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray &self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + nb_property(&PyArray::arrays, [](PyArray &self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + type.attr("_npy_value") = + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + 
type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + [](PyArray &self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("_replace_with") = nb::cpp_function( + [](PyArray &self, PyArray &o) { + xla::ThrowIfError(self.ReplaceWithAlias(o)); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + if (self.ifrt_array()->client()->platform_name() == "cuda" || + self.ifrt_array()->client()->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return self.ifrt_array()->client()->platform_name(); + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + auto *client = arrays[0].ifrt_array()->client(); + std::vector device_lists; + device_lists.reserve(dst_device_lists.size()); + for (const auto &dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto &d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def("__call__", [](const PyArrayResultHandler &self, + PyArray arg) { return self.Call(arg); }) + .def("__call__", + [](const PyArrayResultHandler &self, + std::vector py_arrays) { return self.Call(py_arrays); }); + + return absl::OkStatus(); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_array.h b/tests/ci_clangformat/py_array.h new file mode 100644 index 0000000..8aff178 --- /dev/null +++ b/tests/ci_clangformat/py_array.h @@ -0,0 +1,362 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_ARRAY_H_ +#define JAXLIB_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" +#include "llvm/Support/Casting.h" + +namespace xla { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); + + PyHostValue(const PyHostValue &) = delete; + PyHostValue(PyHostValue &&) = delete; + PyHostValue &operator=(const PyHostValue &) = delete; + PyHostValue &operator=(PyHostValue &&) = delete; + + absl::Status CopyToHostAsync(std::optional &dynamic_shape_holder, + ifrt::Array *ifrt_array); + + absl::StatusOr> AsNumPyArray( + std::optional &dynamic_shape_holder, ifrt::Array *ifrt_array); + + void Clear(); + + private: + absl::Status CopyStringArrayToHostAsync( + std::optional &dynamic_shape_holder, ifrt::Array *ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array *ifrt_array); + + ifrt::Future<> ready_; + nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until it is lazily converted it to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + std::optional traceback; + ifrt::ArrayRef ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. 
Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. + PyArray_Storage *next; + PyArray_Storage *prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. + PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, + std::optional traceback, ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, nanobind::object sharding, bool weak_type, + bool committed, bool skip_checks); + + static absl::Status RegisterTypes(nanobind::module_ &m); + + static PyArray borrow(PyObject *ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object &aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const nb_dtype &dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object &sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object &npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr &py_client() const { + return GetStorage().py_client; + } + + const std::optional &traceback() const { + return GetStorage().traceback; + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + ifrt::Array *ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return InvalidArgument("Array has been deleted."); + } + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::PjRtFuture<> &result_status() const { + return GetStorage().result_status; + } + + ifrt::Array *ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. 
+ absl::Span> pjrt_buffers() const { + ifrt::Array *ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto *arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + ifrt::Array *ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto *arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. + return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector &py_arrays() { return GetStorage().py_arrays; } + const std::vector &py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector &py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + ifrt::Array *ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + absl::Status ReplaceWithAlias(PyArray o); + + private: + absl::StatusOr AssertUnsharded(absl::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(ifrt::ArrayRef ifrt_array); + + Storage &GetStorage(); + const Storage &GetStorage() const; + + inline static PyObject *type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks); + + PyArray Call(absl::Span py_arrays) const; + PyArray Call(PyArray py_array) const; + + PyArray Call(nb_class_ptr py_client, ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + nb_dtype dtype_; + std::vector shape_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict &cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // namespace 
xla + +#endif // JAXLIB_PY_ARRAY_H_ diff --git a/tests/ci_clangformat/py_client.cc b/tests/ci_clangformat/py_client.cc new file mode 100644 index 0000000..ff60c81 --- /dev/null +++ b/tests/ci_clangformat/py_client.cc @@ -0,0 +1,1021 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_host_callback.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include 
"xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "llvm/Support/Casting.h" + +namespace xla { + +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device *device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory *memory : device->Memories()) { + auto &py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device *device) { + auto &py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory *memory_space) { + auto &py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device *device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device *device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device *device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::list PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable *exec = executables_; exec; exec = exec->next_) { + executables.append(nb::find(exec)); + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. 
+ + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. + std::vector *> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray &array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. + if (array.ifrt_array() == nullptr) { + continue; + } + auto *arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr &pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { + continue; + } + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); + } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + } + } + + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto &it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer *pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } + + // Copy host copies back to device and update PyArrays in-place. + for (auto &it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer *pjrt_buf = it.first; + TmpBuffer &tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, + pjrt_buf->memory_space()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr *pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; + } + } + + // TODO(skyewm): delete executables? + return absl::OkStatus(); +} + +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, ifrt::Device *device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { + if (device == nullptr) { + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); + } + CHECK(device != nullptr); + + auto transfer_guard_formatter = [&argument, dst_device = device] { + auto type = nb::cast(nb::str(argument.type())); + // Catch exceptions because shape and dtype properties convertible to str + // are not guaranteed to present in an arbitrary argument. 
+ std::string shape; + std::string dtype; + try { + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); + } catch (const std::exception &e) { + shape = ""; + } + try { + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); + } catch (const std::exception &e) { + dtype = ""; + } + return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype, + ", dst_device=", dst_device->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(ifrt::Device * found_device, + client->ifrt_client_->LookupDevice(device->Id())); + if (found_device != device) { + return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), + client->ifrt_client_->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + DevicePutOptions options; + options.squash_64bit_types = false; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + TF_ASSIGN_OR_RETURN(DevicePutResult device_put_result, + DevicePutWithDevice(argument, client->ifrt_client_.get(), + device, ifrt::MemoryKind(), options)); + auto sharding = make_nb_class( + client, client->ifrt_client()->MakeDeviceList({device}), + /*memory_kind=*/nb::none()); + + auto traceback = Traceback::Get(); + return PyArray::MakeFromIfrtArrayAndSharding( + std::move(client), std::move(traceback), + std::move(device_put_result.ifrt_array), std::move(sharding), + /*weak_type=*/false, /*committed=*/false, + /*skip_checks=*/true); +} + +namespace { + +// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host +// callbacks. +std::unique_ptr MakeIfrtCompileOptions( + CompileOptions options, ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto &host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } +#if JAX_IFRT_VERSION_NUMBER >= 6 + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif +} + +// Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and +// optional host callbacks. +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. 
+ for (auto &host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } +#if JAX_IFRT_VERSION_NUMBER >= 6 + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileAndLoadIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto *pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto *ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions &options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->CompileAndLoad( + std::move(ifrt_program), std::move(ifrt_options))); + TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. 
+ TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + return CompileAndLoadIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), + std::move(host_callbacks))); +} + +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto &host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } +#if JAX_IFRT_VERSION_NUMBER >= 6 + auto compile_options = std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else + auto compile_options = std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif + return CompileAndLoadIfrtProgram( + client, std::make_unique(module.get()), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable &executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(executable_devices), + std::move(host_callbacks)); + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + absl::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + Traceback *traceback; + int64_t size; + xla::PjRtDevice *device; + bool operator==(const HeapProfileKey &other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey &other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback == nullptr) != (other.traceback == nullptr)) { + return false; + } + if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey &key) { + if (key.traceback) { + h = H::combine(std::move(h), 
key.traceback->raw_frames()); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](PjRtBuffer *buffer, Traceback *traceback) { + // We only wish to count each PjRtBuffer once, even though they may be + // shared by multiple PyArrays. + if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray &array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto *arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto &buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR(add_buffer_to_profile( + buffer.get(), + array.traceback() ? array.traceback()->get() : nullptr)); + } + } + + for (PyLoadedExecutable *executable = executables_; executable; + executable = executable->next_) { + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + + PprofProfileBuilder builder; + auto *allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto *space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto &entry : entries) { + auto *sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto &frame : entry.first.traceback->raw_frames()) { + sample->add_location_id(builder.LocationId(frame.first, frame.second)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto *kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto *device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + 
loaded_host_callback.release(), [](void *ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +/* static */ int PyClient::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyClient *c = nb::inst_ptr(self); + for (const auto &[ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto &[ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject *self) { + PyClient *c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> + memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void *)PyClient::tp_traverse}, + {Py_tp_clear, (void *)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_ &m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice *device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? 
device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. 
+ .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + nb::sequence &py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + nb::sequence &py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + 
nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + nb::sequence &py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + nb::sequence &py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("compile_and_load_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + jax::PyDeviceList &py_executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + // The following overload is for users of deprecated APIs who call + // `deserialize_executable` but do not have visibility to `DeviceList`. + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + nb::sequence &py_executable_devices, + std::optional options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. 
+ .def("defragment", + [](PyClient &self) { xla::ThrowIfError(self.Defragment()); }) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient &self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient &client, absl::string_view name) -> nb::object { + const auto &attrs = client.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto &&v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_client.h b/tests/ci_clangformat/py_client.h new file mode 100644 index 0000000..b37ab18 --- /dev/null +++ b/tests/ci_clangformat/py_client.h @@ -0,0 +1,256 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_H_ +#define JAXLIB_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" +#include "llvm/Support/Casting.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. 
+ explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + ifrt::Client *ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr &shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + xla::PjRtClient *pjrt_client() const { + auto *pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->pjrt_client(); + } + std::shared_ptr shared_ptr_pjrt_client() { + auto *pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->shared_ptr_pjrt_client(); + } + + // Legacy alises. + std::shared_ptr shared_pjrt_client() { + return shared_ptr_pjrt_client(); + } + + absl::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. + if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } + absl::string_view raw_platform_name() const { + // TODO(parkers): Once platform_name() is the same, remove this. + return ifrt_client_->platform_name(); + } + absl::string_view platform_version() const { + return ifrt_client_->platform_version(); + } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. + const xla::ifrt::AttributeMap &Attributes() const { + return client_attributes_; + } + + int addressable_device_count() const { + return ifrt_client_->addressable_device_count(); + } + int device_count() const { return ifrt_client_->device_count(); } + int process_index() const { return ifrt_client_->process_index(); } + + std::vector> Devices(); + std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( + int local_hardware_id); + + // Returns the PyDevice associated with the given ifrt::Device. + nb_class_ptr GetPyDevice(ifrt::Device *device); + + // Returns the PyMemorySpace associated with the given ifrt::Memory. + nb_class_ptr GetPyMemorySpace(ifrt::Memory *memory_space); + + // Returns a vector of live PyArray objects. PyArray objects may share + // PjRtBuffers, so there may be duplicates of the same underlying device + // buffer. + std::vector LiveBuffersOnDevice(ifrt::Device *device); + + nanobind::list LiveExecutables(); + + // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. 
+ absl::Status Defragment(); + + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + ifrt::Device *device, bool force_copy, + ifrt::Client::HostBufferSemantics host_buffer_semantics); + + static absl::StatusOr> + CompileAndLoadIfrtProgram(nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); + + absl::StatusOr SerializeExecutable( + const PyLoadedExecutable &executable) const; + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks); + + absl::StatusOr HeapProfile(); + + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable + // that takes in arguments of shapes `operand_shapes` and returns results of + // shapes `result_shapes`. The arguments correspond to Send ops in the HLO + // program through `send_channel_ids` and the results correspond to Recv ops + // through `recv_channel_ids`. It returns the host callback as an opaque + // object whose reference will keep the Python callback alive. The host + // callback can be passed to `PyClient::CompileAndLoad` or + // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the + // XLA computation can trigger the execution of this host callback. + // `serializer` is a function that takes `callable` as an argument and returns + // a serialized callable as a string. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_ &m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. 
+ PyLoadedExecutable *executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage *arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_CLIENT_H_ diff --git a/tests/ci_clangformat/py_client_cpu.cc b/tests/ci_clangformat/py_client_cpu.cc new file mode 100644 index 0000000..fe863c8 --- /dev/null +++ b/tests/ci_clangformat/py_client_cpu.cc @@ -0,0 +1,243 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client_cpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "jaxlib/ffi.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +struct CpuTransposePlanCache { + static ffi::TypeId id; + explicit CpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +ffi::TypeId CpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache", + &CpuTransposePlanCache::id); + +static ffi::ErrorOr> +CpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(/*capacity=*/16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate, + ffi::Ffi::BindInstantiate().Attr("index")); + +ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks *callbacks, + CpuTransposePlanCache *transpose_cache, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. 
+ if (ptype == S1 || ptype == U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + std::unique_ptr buffer; + const void *data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t size_bytes = arg->element_count() * bits_per_element / 8; + buffer = xla::UnpackIntN(bits_per_element, + static_cast(data), size_bytes); + data = buffer.get(); + } + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + + EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(nb_args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error &e) { + return ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + LeaveHostCallback(); + + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + if (ptype == S1 || ptype == U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. 
+ auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return ffi::Error::Internal(maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = ByteStridesForShape(expected_shape); + + const void *data = array.data(); + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer + // supplied by FFI directly. + buffer = std::make_unique(size_bytes); + plan->Execute(data, buffer.get()); + data = buffer.get(); + } else { + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); + } + } + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. 
+ if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); + } + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_cpu_callback", "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaBufferPythonCpuCallback, + (jax::XlaBufferCallback), + ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_buffer_python_cpu_callback", + "HOST", kXlaBufferPythonCpuCallback); + +} // namespace xla diff --git a/tests/ci_clangformat/py_client_cpu.h b/tests/ci_clangformat/py_client_cpu.h new file mode 100644 index 0000000..275a57f --- /dev/null +++ b/tests/ci_clangformat/py_client_cpu.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_CPU_H_ +#define JAXLIB_PY_CLIENT_CPU_H_ + +#include "xla/ffi/api/ffi.h" + +namespace xla { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace xla + +#endif // JAXLIB_PY_CLIENT_CPU_H_ diff --git a/tests/ci_clangformat/py_compile_only_client.cc b/tests/ci_clangformat/py_compile_only_client.cc new file mode 100644 index 0000000..8a2836a --- /dev/null +++ b/tests/ci_clangformat/py_compile_only_client.cc @@ -0,0 +1,145 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/version.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "llvm/Support/Casting.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr CompileUnloaded( + absl::string_view mlir_module, ifrt::DeviceListRef executable_devices, + CompileOptions options, std::vector host_callbacks) { + if (!host_callbacks.empty()) { + return Unimplemented( + "Compiling with host_callbacks not available with compile-only " + "client."); + } + nb::gil_scoped_release gil_release; + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. 
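+      // A failed export aborts CompileUnloaded() here, before PjRtCompile()
+      // below is ever reached.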
+ TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + auto *ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "CompileOnlyIfRtClient"; +#if JAX_IFRT_VERSION_NUMBER >= 6 + auto xla_options = std::make_unique( + options, std::move(executable_devices)); +#else + auto xla_options = std::make_unique(options); +#endif + TF_ASSIGN_OR_RETURN(auto executable, + PjRtCompile(std::move(options), module.get(), + *ifrt_client->topology().description())); + TF_ASSIGN_OR_RETURN(auto ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(executable))); + return ifrt::ExecutableRef(std::move(ifrt_executable)); + } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } +}; + +} // namespace + +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr topology) { + return CompileOnlyPyClient::Make(std::move(topology)); +} + +void RegisterCompileOnlyClient(nb::module_ &m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient &self, nb::bytes mlir_module, + jax::PyDeviceList &py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(self.CompileUnloaded( + absl::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def("compile", + ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_compile_only_client.h b/tests/ci_clangformat/py_compile_only_client.h new file mode 100644 index 0000000..599052b --- /dev/null +++ b/tests/ci_clangformat/py_compile_only_client.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + +namespace xla { + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). 
RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr); + +void RegisterCompileOnlyClient(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/tests/ci_clangformat/py_device.cc b/tests/ci_clangformat/py_device.cc new file mode 100644 index 0000000..d23c860 --- /dev/null +++ b/tests/ci_clangformat/py_device.cc @@ -0,0 +1,350 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "llvm/Support/Casting.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device *device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +absl::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. 
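+  // Illustrative effect of the shim (hypothetical CUDA backend): this method
+  // returns "gpu" while client_->platform_name() still reports "cuda", so
+  // Python-side checks such as `device.platform == "gpu"` keep working.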
+ if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. + ifrt::PjRtDevice *device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +absl::string_view PyDevice::Str() const { return device_->DebugString(); } + +absl::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferToInfeed is only supported for PjRt devices."); + } + return client->TransferToInfeed(device, literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferFromOutfeed is only supported for PjRt devices."); + } + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape *subshape, const ShapeIndex &) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(client->TransferFromOutfeed(device, literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + absl::string_view kind) const { + ifrt::Memory *result_memory_space = nullptr; + for (auto *memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string *out, const auto &memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string *out, const auto &memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. 
" + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto *memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto *memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice *device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice *device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + PyDevice *d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject *self) { + PyDevice *d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void *)PyDevice::tp_traverse}, + {Py_tp_clear, (void *)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_ &m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject *self, PyObject *args) -> PyObject * { + PyObject *key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto &attrs = device->device_->Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = std::visit([](auto &&v) { return nb::cast(v.value); }, + it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception &e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) 
{ + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_device.h b/tests/ci_clangformat/py_device.h new file mode 100644 index 0000000..6cae95a --- /dev/null +++ b/tests/ci_clangformat/py_device.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_H_ +#define JAXLIB_PY_DEVICE_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device *device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice &) = delete; + PyDevice(PyDevice &&) = delete; + PyDevice &operator=(const PyDevice &) = delete; + PyDevice &operator=(PyDevice &&) = delete; + + const nb_class_ptr &client() const { return client_; } + ifrt::Device *device() const { return device_; } + + int id() const; + int process_index() const; + absl::string_view platform() const; + absl::string_view device_kind() const; + std::optional local_hardware_id() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + absl::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_ &m); + + private: + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device *device_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_DEVICE_H_ diff --git a/tests/ci_clangformat/py_device_list.cc b/tests/ci_clangformat/py_device_list.cc new file mode 100644 index 0000000..298f376 --- /dev/null +++ b/tests/ci_clangformat/py_device_list.cc @@ -0,0 +1,482 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/python_ref_manager.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/types.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + xla::GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. 
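+  // Distinct wrappers may still be equal: the slower paths below compare the
+  // cached hashes first (each computed under its owner's lock) and only fall
+  // back to an element-wise comparison of the underlying device lists.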
+ if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef &device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef &device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef &device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device *device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. 
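+      // nanobind::make_iterator drives this adapter with just increment,
+      // equality comparison, and dereference; each dereference goes through
+      // py_client->GetPyDevice() so the caller sees Python device wrappers.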
+ struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator &other) const { return it == other.it; } + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + xla::nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + ProcessIndices(); + CHECK(process_indices_.has_value()); + if (process_indices_->size() > 1) { + is_fully_addressable_ = false; + } else { + CHECK_EQ(process_indices_->size(), 1); + int process_index; + switch (device_list_.index()) { + case 0: { + process_index = py_client_ ? py_client_->process_index() : 0; + break; + } + case 1: { + process_index = + nb::cast(std::get<1>(device_list_)[0].attr("client").attr( + "process_index")()); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + is_fully_addressable_ = *process_indices_->begin() == process_index; + } + } + return *is_fully_addressable_; +} + +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? 
self->py_client_->process_index() : 0; + for (xla::ifrt::Device *device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices)); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *self->addressable_device_list_; +} + +const std::set &PyDeviceList::ProcessIndices() { + if (!process_indices_.has_value()) { + process_indices_ = std::set{}; + switch (device_list_.index()) { + case 0: { + for (const xla::ifrt::Device *device : + std::get<0>(device_list_)->devices()) { + process_indices_->insert(device->ProcessIndex()); + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + process_indices_->insert(nb::cast(device.attr("process_index"))); + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *process_indices_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. + PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + if (std::get<0>(device_list_)->size() == 0) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + xla::ifrt::Device *device = std::get<0>(device_list_)->devices()[0]; + + auto default_memory = device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(device->Memories().size())); + for (size_t i = 0; i < device->Memories().size(); ++i) { + auto *memory = device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + if (std::get<1>(device_list_).size() == 0) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. + memory_kind_info_ = std::move(info); + return; + } + nb::handle device = std::get<1>(device_list_)[0]; + auto default_memory = device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = + nb::tuple(nb::object(device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error &e) { + // Cache the error. 
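+    // The failed status is cached in memory_kind_info_, so later
+    // DefaultMemoryKind()/MemoryKinds() calls return the same error without
+    // re-entering Python.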
+ memory_kind_info_ = xla::InvalidArgument("%s", e.what()); + } +} + +/*static*/ absl::StatusOr PyDeviceList::MemoryKinds( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->memory_kinds; +} + +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->default_memory_kind; +} + +/*static*/ void PyDeviceList::Register(nb::module_ &m) { + nb::class_(m, "DeviceList") + .def(nb::init()) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) + .def("__len__", &PyDeviceList::Len) + .def("__getitem__", &PyDeviceList::GetItem) + .def("__getitem__", &PyDeviceList::GetSlice) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) + .def("__str__", &PyDeviceList::Str) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList &l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList &self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) + .def_prop_ro("process_indices", &PyDeviceList::ProcessIndices, + nb::lock_self()) + // `xla::ValueOrThrowWrapper` does not work with + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro("default_memory_kind", + [](xla::nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }) + .def_prop_ro("memory_kinds", [](xla::nb_class_ptr l) { + auto kinds = MemoryKinds(l); + if (!kinds.ok()) { + throw nb::value_error(kinds.status().ToString().c_str()); + } + return *kinds; + }); +} + +} // namespace jax diff --git a/tests/ci_clangformat/py_device_list.h b/tests/ci_clangformat/py_device_list.h new file mode 100644 index 0000000..7037674 --- /dev/null +++ b/tests/ci_clangformat/py_device_list.h @@ -0,0 +1,142 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_LIST_H_ +#define JAXLIB_PY_DEVICE_LIST_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/device_list.h" + +namespace jax { + +// Device list with various caching and direct access to IFRT DeviceList. 
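+//
+// A minimal construction sketch (hypothetical call site, for illustration;
+// d0/d1 stand for xla::ifrt::Device pointers):
+//
+//   xla::ifrt::DeviceListRef devices =
+//       py_client->ifrt_client()->MakeDeviceList({d0, d1});
+//   auto list =
+//       xla::make_nb_class<PyDeviceList>(py_client, std::move(devices));
+//   list->Len();       // -> 2
+//   list->GetItem(0);  // -> Python wrapper for d0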
+class PyDeviceList { + public: + PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list); + explicit PyDeviceList(nanobind::tuple py_device_assignment); + ~PyDeviceList(); + + PyDeviceList(const PyDeviceList &) = delete; + PyDeviceList(PyDeviceList &&) = delete; + PyDeviceList &operator=(const PyDeviceList &) = delete; + PyDeviceList &operator=(PyDeviceList &&) = delete; + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + // These two methods are safe to call from C++ without GIL. + xla::nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; + + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr MemoryKinds( + xla::nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/py/jax/jaxlib/xla.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_ &m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); + + std::string Str(); + + nanobind::tuple Dump() const; + + int64_t Hash(); // Mutates hash_, needs self lock. + + static bool Equal(xla::nb_class_ptr self, + nanobind::handle other); + static bool NotEqual(xla::nb_class_ptr self, + nanobind::handle other); + + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. + void PopulateMemoryKindInfo(); + // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` + // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. + void PopulateMemoryKindInfoForDuckTypedDevices(); + + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + + // Requires the self lock or GIL. + const std::set &ProcessIndices(); + + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and + // non-empty. + xla::nb_class_ptr py_client_; + + // Either C++ `ifrt::DeviceList` or Python duck-type devices. + // TODO(hyeontaek): Remove support for Python duck-type devices once all + // JAX backends and tests are migrated to use an `xla::ifrt::Device` type + // for JAX devices. + // Immutable after constructor; no locking needed. + std::variant device_list_; + + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; + // TODO(hyeontaek): Make the following property cached within + // `xla::ifrt::DeviceList`. + // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional> addressable_device_list_; + // Populated on demand. Guarded by the object's self lock. + std::optional> process_indices_; + + struct MemoryKindInfo { + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; + }; + // Populated on demand. Guarded by the object's self lock. 
+ std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_DEVICE_LIST_H_ diff --git a/tests/ci_clangformat/py_executable.cc b/tests/ci_clangformat/py_executable.cc new file mode 100644 index 0000000..993b007 --- /dev/null +++ b/tests/ci_clangformat/py_executable.cc @@ -0,0 +1,427 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +namespace nb = nanobind; + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + nb::gil_scoped_release gil_release; + return future_.Await(); +} + +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); + for (auto &future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + return status; +} + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)), + next_launch_id_( + fingerprint_.has_value() ? 
tsl::Fingerprint32(*fingerprint_) : 1) { + CHECK(PyGILState_Check()); + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device *device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +static int GetNumDevices(const ExecuteShardedArg &arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); + } +} +static ifrt::ArrayRef GetIfRtArray(const ExecuteShardedArg &arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); + } + auto &arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto &arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
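+  // The assembly below reuses the first shard's shape as a stand-in for the
+  // global shape, per the note above; callers must not rely on it being the
+  // true global shape.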
+ ifrt::Client *client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; +} + +void PopulateExecuteShardedResults(const nb_class_ptr &client, + std::vector ifrt_arrays, + const PjRtFuture<> &result_status, + int num_computations, + std::vector> &outputs) { + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto &exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, traceback, std::move(exploded_array), false, true, + result_status)); + } + } +} + +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions &options, const nb_class_ptr &client, + ifrt::LoadedExecutable *ifrt_loaded_executable, + absl::Span args, + std::optional>> &returned_futures) { + std::vector output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + PjRtFuture<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto &arg : args) { + if (GetNumDevices(arg) != num_computations) { + return InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", + [](std::string *out, const ExecuteShardedArg &arg) { + out->append(std::to_string(GetNumDevices(arg))); + })); + } + } + std::vector arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), + [&](const ExecuteShardedArg &arg) mutable { + return GetIfRtArray(arg); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? 
PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +PyExecuteResults::PyExecuteResults(const nb_class_ptr &client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers) { + std::vector outputs; + auto ifrt_arrays = Consume(); + auto traceback = Traceback::Get(); + int num_output_buffers = ifrt_arrays.size(); + outputs.reserve(num_output_buffers); + if (out_handlers.size() != num_output_buffers) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto &handler = out_handlers[buffer_id]; + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : PjRtFuture<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto &disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_.IsValid() ? 
result_status_ : PjRtFuture<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> PyLoadedExecutable::GetOutputShardings() + const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int32_t PyLoadedExecutable::GetNextLaunchId() { + return absl::bit_cast( + next_launch_id_.fetch_add(1, std::memory_order_relaxed)); +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_executable.h b/tests/ci_clangformat/py_executable.h new file mode 100644 index 0000000..d9537b4 --- /dev/null +++ b/tests/ci_clangformat/py_executable.h @@ -0,0 +1,246 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PY_EXECUTABLE_H_ +#define JAXLIB_PY_EXECUTABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/status.h" +#include "xla/xla_data.pb.h" +#include "llvm/Support/Casting.h" + +namespace xla { + +class PyToken { + public: + PyToken() = default; + explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {} + + static PyToken ReadyPyToken() { + return PyToken(PjRtFuture<>(absl::OkStatus())); + } + + absl::Status Await(); + + private: + PjRtFuture<> future_; +}; + +// PyShardedToken contains a PyToken for each device's execution. +class PyShardedToken { + public: + // Default construction creates a always-ready token. + PyShardedToken() = default; + explicit PyShardedToken(std::vector> futures) + : futures_(std::move(futures)) {} + + PyToken GetPyToken(int device_id) const { + if (futures_.empty()) return PyToken::ReadyPyToken(); + return PyToken(futures_.at(device_id)); + } + + absl::Status Await(); + + private: + std::vector> futures_; +}; + +class PyExecuteResults { + public: + PyExecuteResults(const nb_class_ptr &client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status = PjRtFuture<>()); + + std::vector> DisassembleIntoSingleDeviceArrays(); + + std::vector> DisassemblePrefixIntoSingleDeviceArrays( + size_t n); + + std::vector ConsumeWithHandlers( + std::vector> + out_handlers); + + std::vector Consume(); + + PyShardedToken ConsumeToken(); + + size_t Size() const { + CheckNotDisassembled(); + return ifrt_arrays_.size(); + } + + void CheckNotDisassembled() const; + + private: + bool is_exploded_ = false; + bool token_consumed_ = false; + nb_class_ptr client_; + std::vector ifrt_arrays_; + int num_computations_; + PyShardedToken token_; + // Only set if the computation has tokens. + PjRtFuture<> result_status_; +}; + +using ExecuteShardedArg = std::variant>; + +// Python wrapper around PjRtExecutable. We use a wrapper class: +// a) to keep the PyClient alive via a std::shared_ptr<> +// b) to add Python-specific functionality. 
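+//
+// In this header the keep-alive in (a) is the nb_class_ptr `client_` member
+// below (a Python reference to the PyClient), and each instance also links
+// itself into the owning client's intrusive list of live executables (see
+// next_/prev_ at the bottom of the class).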
+class PyLoadedExecutable { + public: + PyLoadedExecutable(nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + ifrt::LoadedExecutable *ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + ifrt::LoadedExecutableRef shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // PjRtExecutable::Execute. The result is similarly transposed back into the + // argid,deviceid format. + // args is [num_args x num_devices]. + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); + + absl::StatusOr>> HloModules() const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + const std::optional &traceback() { return traceback_; } + + ifrt::LoadedExecutable *ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto *exec = llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. + const ifrt::ExecuteOptions &options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int32_t GetNextLaunchId(); + + const std::optional &fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + private: + friend class PyClient; + + nb_class_ptr client_; + ifrt::LoadedExecutableRef ifrt_loaded_executable_; + std::optional traceback_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + std::atomic next_launch_id_; + + // The options to pass to `executable_.Execute`. + ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. + std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. 
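+  // (These intrusive links let the client enumerate its live executables
+  // without owning them; as noted above, they are only touched while the GIL
+  // is held.)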
+ PyLoadedExecutable *next_; + PyLoadedExecutable *prev_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_EXECUTABLE_H_ diff --git a/tests/ci_clangformat/py_host_callback.cc b/tests/ci_clangformat/py_host_callback.cc new file mode 100644 index 0000000..5c7fa3d --- /dev/null +++ b/tests/ci_clangformat/py_host_callback.cc @@ -0,0 +1,259 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_host_callback.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "jaxlib/callback.h" +#include "jaxlib/py_host_callback.pb.h" +#include "jaxlib/python_ref_manager.h" +#include "nanobind/nanobind.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "llvm/Support/ExtensibleRTTI.h" + +namespace nb = nanobind; + +namespace xla { + +char PyFfiLoadedHostCallback::ID = 0; +char PyHostSendAndRecvLoadedHostCallback::ID = 0; + +namespace { + +absl::StatusOr> CreateCallbackArgs( + absl::Span operand_shapes) { + std::vector callback_args(operand_shapes.size()); + for (int i = 0; i < operand_shapes.size(); ++i) { + Shape shape = operand_shapes[i]; + + if (shape.IsArray()) { + Shape layout = + (shape.has_layout() ? shape + : LayoutUtil::GetWithDefaultLayout(shape)); + callback_args[i].dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), callback_args[i].dims.begin()); + callback_args[i].strides = ByteStridesForShape(layout); + callback_args[i].type = shape.element_type(); + callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout); + TF_ASSIGN_OR_RETURN(callback_args[i].dtype, + PrimitiveTypeToNbDtype(shape.element_type())); + } else if (shape.IsToken()) { + callback_args[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token arguments to Python callbacks are supported, " + "got %s", + shape.ToString()); + } + } + return callback_args; +} + +absl::StatusOr> CreateCallbackResults( + absl::Span result_shapes) { + std::vector callback_results(result_shapes.size()); + for (int i = 0; i < result_shapes.size(); ++i) { + if (result_shapes[i].IsArray()) { + const Shape &shape = + result_shapes[i].has_layout() + ? 
result_shapes[i] + : LayoutUtil::GetWithDefaultLayout(result_shapes[i]); + callback_results[i].expected_dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), + callback_results[i].expected_dims.begin()); + callback_results[i].expected_strides = ByteStridesForShape(shape); + callback_results[i].type = shape.element_type(); + callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape); + callback_results[i].reversed_layout.resize(shape.dimensions_size()); + absl::c_reverse_copy(shape.layout().minor_to_major(), + callback_results[i].reversed_layout.begin()); + } else if (result_shapes[i].IsToken()) { + callback_results[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token return values from Python callbacks are " + "supported, got %s", + result_shapes[i].ToString()); + } + } + return callback_results; +} + +} // namespace + +PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::StatusOr> +PyHostSendAndRecvLoadedHostCallback::Create( + ifrt::Client *ifrt_client, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = + std::make_shared(callable, callback_args, callback_results); + + auto host_callback = std::make_unique(); + + auto assign_arg_info = [](absl::Span shapes, + absl::Span channel_ids, + std::vector &arg_infos) { + DCHECK_EQ(shapes.size(), channel_ids.size()); + arg_infos.reserve(shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + HostCallbackArgInfo host_callback_arg_info; + host_callback_arg_info.channel_id = channel_ids[i]; + const auto &shape = shapes[i]; + Shape layout = + (shape.has_layout() ? 
shape + : LayoutUtil::GetWithDefaultLayout(shape)); + host_callback_arg_info.shape = layout; + arg_infos.push_back(std::move(host_callback_arg_info)); + } + }; + + assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands); + assign_arg_info(result_shapes, recv_channel_ids, host_callback->results); + + host_callback->callback = [cpu_callback = std::move(cpu_callback)]( + void **outputs, void **inputs) { + return cpu_callback->PrepareAndCall(outputs, inputs); + }; + return tsl::RCReference( + tsl::MakeRef( + ifrt_client, std::move(host_callback), callable, operand_shapes, + result_shapes, send_channel_ids, recv_channel_ids, + std::move(serializer))); +} + +PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( + ifrt::Client *ifrt_client, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) + : llvm::RTTIExtends( + ifrt_client, std::move(xla_host_callback)), + callable_(std::move(callable)), + operand_shapes_(operand_shapes.begin(), operand_shapes.end()), + result_shapes_(result_shapes.begin(), result_shapes.end()), + send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()), + recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()), + serializer_(serializer) {} + +PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&callable_), 1)); + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&serializer_), 1)); +} + +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { + if (serializer_.is_none()) { + return InvalidArgument( + "Host callback cannot be serialized because serializer was not " + "provided by JAX"); + } + ifrt::XlaHostCallbackProto xla_host_callback_proto; + + TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size()); + for (int i = 0; i < operand_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo *const operand = + xla_host_callback_proto.add_operands(); + operand->set_channel_id(send_channel_ids_[i]); + *operand->mutable_shape() = operand_shapes_[i].ToProto(); + } + + TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size()); + for (int i = 0; i < result_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo *const result = + xla_host_callback_proto.add_results(); + result->set_channel_id(recv_channel_ids_[i]); + *result->mutable_shape() = result_shapes_[i].ToProto(); + } + + std::string callable; + { + nb::gil_scoped_acquire gil_acquire; + try { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error &e) { + return absl::InternalError(absl::StrCat( + "Unable to pickle the host_callback callable: ", e.what())); + } catch (const std::exception &e) { + std::exception_ptr p = std::current_exception(); + return absl::InternalError(absl::StrCat( + "Exception while pickling the host_callback callable: ", e.what())); + } catch (...) { + // Ensure to avoid leaking any exception because this method could have + // been called outside of a Python context where C++ exceptions are not + // necessarily enabled. 
+ return absl::InternalError( + "Unknown exception while pickling the host_callback callable."); + } + } + PyHostCallbackProto py_host_callback_proto; + py_host_callback_proto.set_callable(std::move(callable)); + if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom( + py_host_callback_proto)) { + return absl::InternalError("Could not serialize a Python host callback"); + } + xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks( + true); + return xla_host_callback_proto.SerializeAsString(); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_host_callback.h b/tests/ci_clangformat/py_host_callback.h new file mode 100644 index 0000000..0bbdbb9 --- /dev/null +++ b/tests/ci_clangformat/py_host_callback.h @@ -0,0 +1,119 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_HOST_CALLBACK_H_ +#define JAXLIB_PY_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" +#include "llvm/Support/ExtensibleRTTI.h" + +namespace xla { + +using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback; + +class PyFfiLoadedHostCallback final + : public llvm::RTTIExtends { + public: + PyFfiLoadedHostCallback(ifrt::Client *ifrt_client, + nanobind::callable callable) + : llvm::RTTIExtends(ifrt_client, + callable.ptr()), + callable_(std::move(callable)) {} + ~PyFfiLoadedHostCallback() override; + + ifrt::Client *client() const override { return ifrt_client_; } + absl::StatusOr Serialize() const override { + return Unimplemented( + "PyFfiLoadedHostCallback::Serialize() is not supported"); + }; + + static char ID; // NOLINT + + private: + ifrt::Client *ifrt_client_; + nanobind::callable callable_; +}; + +// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that +// uses XLA host send and recv. This object should be passed to the compiler +// when creating `xla::ifrt::LoadedExecutable`. +// +// Serialization is supported if the Python host callback using the +// `cloudpickle` third-party library. +// +// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting +// compilation and loading. +class PyHostSendAndRecvLoadedHostCallback final + : public llvm::RTTIExtends { + public: + static absl::StatusOr> + Create(ifrt::Client *ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + // PjRtLoadedHostCallback implementation. 
+ + ~PyHostSendAndRecvLoadedHostCallback() override; + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyHostSendAndRecvLoadedHostCallback( + ifrt::Client *ifrt_client, + std::unique_ptr xla_host_callback, + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + template + friend tsl::RCReference tsl::MakeRef(Args &&...args); + + // Retained arguments for host callback serialization. + nanobind::callable callable_; + std::vector operand_shapes_; + std::vector result_shapes_; + std::vector send_channel_ids_; + std::vector recv_channel_ids_; + nanobind::callable serializer_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_HOST_CALLBACK_H_ diff --git a/tests/ci_clangformat/py_memory_space.cc b/tests/ci_clangformat/py_memory_space.cc new file mode 100644 index 0000000..679beda --- /dev/null +++ b/tests/ci_clangformat/py_memory_space.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_memory_space.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/python/ifrt/device.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory *memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +absl::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. 
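+  // For example, a client whose platform_name() is "cuda" or "rocm" reports
+  // "gpu" here; all other platform names are returned unchanged.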
+ if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device *device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + PyMemorySpace *d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject *self) { + PyMemorySpace *d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void *)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void *)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_ &m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_memory_space.h b/tests/ci_clangformat/py_memory_space.h new file mode 100644 index 0000000..8cdfcf2 --- /dev/null +++ b/tests/ci_clangformat/py_memory_space.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_MEMORY_SPACE_H_ +#define JAXLIB_PY_MEMORY_SPACE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, ifrt::Memory *memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. 
+ PyMemorySpace(const PyMemorySpace &) = delete; + PyMemorySpace(PyMemorySpace &&) = delete; + PyMemorySpace &operator=(const PyMemorySpace &) = delete; + PyMemorySpace &operator=(PyMemorySpace &&) = delete; + + const nb_class_ptr &client() const { return client_; } + ifrt::Memory *memory_space() const { return memory_; } + + int process_index() const; + absl::string_view platform() const; + absl::string_view kind() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_ &m); + + private: + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Memory *memory_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_MEMORY_SPACE_H_ diff --git a/tests/ci_clangformat/py_program.cc b/tests/ci_clangformat/py_program.cc new file mode 100644 index 0000000..9ec76a6 --- /dev/null +++ b/tests/ci_clangformat/py_program.cc @@ -0,0 +1,301 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_program.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. 
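+// Accepts either a jax PyDeviceList or a plain sequence of JAX device
+// objects; an empty sequence is rejected with an InvalidArgumentError.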
+absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(jax::PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr &py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr GetIfrtSharding(nb::handle sharding, + int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + ifrt::ShardingRef ifrt_sharding; + if (sharding.type().is(jax::SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). 
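+// Each aval is expected to expose `shape`, `dtype`, and `sharding`
+// attributes, which are converted to their IFRT equivalents below.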
+absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, jax::PyDeviceList &py_executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. 
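+  // (Each capsule is assumed to carry a raw ifrt::LoadedHostCallback pointer;
+  // the loop below only wraps it in a reference-counted handle.)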
+ for (auto &host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } +#if JAX_IFRT_VERSION_NUMBER >= 6 + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef executable_devices, + py_executable_devices.ifrt_device_list()); + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +#else + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +#endif +} + +constexpr absl::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + absl::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](absl::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_ &m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + ValueOrThrowWrapper(MakeColocatedPythonProgram), nb::arg("name"), + nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromString), nb::arg("data")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) + .def("make_xla_compile_options", + ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("executable_devices"), nb::arg("host_callbacks")) + .def("make_colocated_python_compile_options", + ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_program.h b/tests/ci_clangformat/py_program.h new file mode 100644 index 0000000..9cfe64c --- /dev/null +++ b/tests/ci_clangformat/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_PROGRAM_H_ +#define JAXLIB_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_PY_PROGRAM_H_ diff --git a/tests/ci_clangformat/py_socket_transfer.cc b/tests/ci_clangformat/py_socket_transfer.cc new file mode 100644 index 0000000..3c1aeff --- /dev/null +++ b/tests/ci_clangformat/py_socket_transfer.cc @@ -0,0 +1,420 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "tsl/platform/casts.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "llvm/Support/Casting.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding &sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto *device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the 
sharding's device + // and matches the sharding's memory_kind. + xla::ifrt::Memory *memory = nullptr; + for (xla::ifrt::Memory *ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string *out, xla::ifrt::Memory *ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory) + ->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +class IfrtArrayEntry : public PullTable::Entry { + public: + struct BufferRef { + xla::ifrt::ArrayRef arr; + xla::PjRtBuffer *buffer; + size_t buf_size; + }; + explicit IfrtArrayEntry(std::vector arrs, + std::shared_ptr state, + size_t xfer_size) + : arrs_(std::move(arrs)), state_(state), xfer_size_(xfer_size) {} + bool Handle(tsl::RCReference state, + const SocketTransferPullRequest &req, + size_t base_req_id) override { + for (uint64_t bid : req.buffer_ids()) { + auto req_id = base_req_id; + ++base_req_id; + for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { + DmaCopyChunk blob = DmaCopyChunk::Make( + std::move(arrs_[bid].arr), arrs_[bid].buffer, bid, i * xfer_size_, + std::min(xfer_size_, arrs_[bid].buf_size - i * xfer_size_)); + bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; + state_->ScheduleCopy( + std::move(blob), [req_id, state, copier_state = state_, is_largest]( + PremappedCopierState *copier_state_ptr, + void *buf, const DmaCopyChunk &chunk) { + state->Send( + req_id, buf, chunk.offset, chunk.size, is_largest, + [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); + }); + } + } + + num_consumed_bufs_ += req.buffer_ids().size(); + return num_consumed_bufs_ == arrs_.size(); + } + + private: + absl::Mutex mu_; + size_t num_consumed_bufs_ = 0; + std::vector arrs_; + std::shared_ptr state_; + size_t xfer_size_; +}; + +absl::StatusOr> CreatePullEntry( + const std::vector &arrs, + std::shared_ptr state, size_t xfer_size) { + std::vector refs; + for (auto &arr : arrs) { + auto *pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto &pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({arr, pjrt_buf.get(), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client *client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress &addr, + const std::vector &transport_addresses, + bool supports_pinned_allocator) { + std::shared_ptr factory; + if 
(transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + std::optional pinned_allocator; + if (supports_pinned_allocator) { + auto tmp = xla::ValueOrThrow( + AllocateNetworkPinnedMemory(xfer_size * max_num_parallel_copies)); + pinned_allocator.emplace(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + } + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, pinned_allocator, uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN(auto mem, + AllocateAndMapPjrtMemory( + client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string &saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, const std::vector &arrs) { + server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( + arrs, premapped_copier_, xfer_size_))); + } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + xla::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + xla::ifrt::ArrayRef arr; + xla::PjRtBuffer *buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace *memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_ &m) { + nb::class_(m, "TransferConnection") + .def("_pull_flat", [](PyTransferServerConnection &self, uint64_t uuid, + xla::nb_class_ptr py_client, + std::vector py_avals) { + auto *ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto &py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto &aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto &shard : shards) { + auto *mem_space = + 
xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + dests[dest_idx].shape_specs.push_back( + {prim_type, xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector< + std::shared_ptr> + atms; + atms.reserve(dests.size()); + + for (auto &dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto &fetch_idx : fetch_idxs) { + auto &atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + self.Pull(uuid, buffer_ids, std::move(pull_dests)); + + std::vector out; + auto traceback = xla::Traceback::Get(); + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto &v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, + std::move(buffers), avals[i].layout)); + out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( + py_client, traceback, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer &self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer &self, uint64_t uuid, + std::vector inputs) { + std::vector arrs; + arrs.reserve(inputs.size()); + for (const xla::PyArray &input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + self.AwaitPull(uuid, arrs); + }) + .def("connect", [](PyTransferServer &self, const std::string &address) { + return self.Connect(address); + }); + + m.def( + "start_transfer_server", + [](xla::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, size_t transfer_size, + bool supports_pinned_allocator) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string &addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses, supports_pinned_allocator)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024, + // Dual pinning not confirmed to be supported. 
+ nb::arg("supports_pinned_allocator") = false); +} + +} // namespace aux diff --git a/tests/ci_clangformat/py_socket_transfer.h b/tests/ci_clangformat/py_socket_transfer.h new file mode 100644 index 0000000..ed2a7af --- /dev/null +++ b/tests/ci_clangformat/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_ &m); + +} // namespace aux + +#endif // JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/tests/ci_clangformat/py_values.cc b/tests/ci_clangformat/py_values.cc new file mode 100644 index 0000000..d5049ee --- /dev/null +++ b/tests/ci_clangformat/py_values.cc @@ -0,0 +1,1097 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/py_array.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +// Gets the thread-local instance. +static DevicePutInfo &GetDevicePutInfo() { + thread_local DevicePutInfo device_put_info; + return device_put_info; +} + +// Prepared data for creating a single shard of an array. Holds a single-device +// IFRT array or a host buffer. +struct Shard { + explicit Shard(ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array_or_host_buffer(std::move(ifrt_array)), + weak_type(weak_type), + // host_buffer_semantics is not meaningful when + // `ifrt_array_or_host_buffer` is an IFRT Array. + host_buffer_semantics( + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall) {} + + Shard(ifrt::Client::HostBuffer ifrt_host_buffer, bool weak_type, + ifrt::Client::HostBufferSemantics host_buffer_semantics) + : ifrt_array_or_host_buffer(std::move(ifrt_host_buffer)), + weak_type(weak_type), + host_buffer_semantics(host_buffer_semantics) {} + + Shard(const Shard &) = delete; + Shard &operator=(const Shard &) = delete; + Shard(Shard &&) noexcept = default; + Shard &operator=(Shard &&) noexcept = default; + + bool is_ifrt_array() const { + return std::holds_alternative(ifrt_array_or_host_buffer); + } + ifrt::DType ifrt_dtype() const; + const ifrt::Shape &ifrt_shape() const; + + // Points to the on-device array or on-host buffer. + std::variant + ifrt_array_or_host_buffer; + bool weak_type; + ifrt::Client::HostBufferSemantics host_buffer_semantics; +}; + +// A function that creates a `Shard` from a Python object when called. 
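+// Handlers return this deferred form rather than a Shard directly; the
+// returned closure is invoked later to materialize the host buffer (or to
+// hand over an already-existing IFRT array).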
+using ShardFn = absl::AnyInvocable() &&>; + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject *py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto *iter_data = PyArray_ITER_DATA(iter.ptr()); + auto *item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +// Handler that creates a `Shard` from a Python object. +using DevicePutHandler = std::function( + nb::handle obj, ifrt::Client *client, ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions &options)>; + +// Shared logic that makes an IFRT array (either single-device or multi-device) +// from a fully-replicated `shard` that is created from a host buffer (not from +// an existing IFRT array). `shard` will be consumed. +// +// `user_context` will be used for a new IFRT array created. +// +// Expected to be called without holding GIL. +absl::StatusOr> +MakeIfrtArrayFromFullyReplicatedShard( + ifrt::Client *ifrt_client, ifrt::ShardingRef ifrt_sharding, Shard &shard, + tsl::RCReference user_context) { + auto host_buffer_shard = std::get( + std::move(shard.ifrt_array_or_host_buffer)); + return ifrt_client->MakeArrayFromHostBuffer( + host_buffer_shard.data, host_buffer_shard.dtype, + std::move(host_buffer_shard.shape), + std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), + shard.host_buffer_semantics, std::move(host_buffer_shard.on_done), + std::move(user_context)); +} + +// Shared logic that makes a single-device IFRT array from a `shard`. `shard` +// will be consumed. +// +// `user_context` will be used for a new IFRT array created from the host +// buffer, and be not applied when reusing an existing IFRT array. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeSingleDeviceIfrtArrayFromShard( + xla::ifrt::Client *ifrt_client, xla::ifrt::Device *ifrt_device, + xla::ifrt::MemoryKind ifrt_memory_kind, Shard &shard, + tsl::RCReference user_context) { + if (auto *ifrt_array = + std::get_if(&shard.ifrt_array_or_host_buffer)) { + return std::move(*ifrt_array); + } + ifrt::ShardingRef ifrt_sharding = + ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); + return MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shard, std::move(user_context)); +} + +// Makes an IFRT Array from `shards` using a batched array creation API (fast +// path). `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeIfrtArrayFromShardsInBatch( + ifrt::Client *ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, absl::Span shards, + tsl::RCReference user_context) { + absl::InlinedVector< + std::pair, ifrt::Client::HostBuffer>, 1> + host_buffers; + host_buffers.reserve(shards.size()); + ifrt::Client::HostBufferSemantics safe_host_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + // TODO(hyeontaek): Deduplicate shards here or early on to create a unique + // HostBuffer for each set of replicated shards. 
+ for (int64_t i = 0; i < shards.size(); ++i) { + host_buffers.push_back({{i}, + std::get(std::move( + shards[i].ifrt_array_or_host_buffer))}); + // The minimum host buffer semantics is a safe semantics that can be used + // for all shards when they are created in a single batch. + safe_host_semantics = + std::min(safe_host_semantics, shards[i].host_buffer_semantics); + } + + std::vector specs; + specs.push_back(ifrt::Client::MakeArraysFromHostBufferShardsSpec{ + std::move(host_buffers), + ifrt::ArraySpec{/*dtype=*/ifrt_dtype, + /*shape=*/std::move(ifrt_shape), + /*sharding=*/std::move(ifrt_sharding), + /*layout=*/nullptr}}); + TF_ASSIGN_OR_RETURN( + auto arrays, + ifrt_client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), safe_host_semantics, std::move(user_context))); + return std::move(arrays.front()); +} + +// Makes an IFRT Array from `shards` using an array assembly API (slow path). +// `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeIfrtArrayFromShardsWithAssembly( + ifrt::Client *ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, + ifrt::DeviceList *ifrt_addressable_device_list, + ifrt::MemoryKind ifrt_memory_kind, absl::Span shards, + tsl::RCReference user_context) { + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + std::vector ifrt_array_shards; + ifrt_array_shards.reserve(shards.size()); + for (int64_t i = 0; i < shards.size(); ++i) { + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array_shard, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_addressable_devices[i], + ifrt_memory_kind, shards[i], user_context)); + ifrt_array_shards.push_back(std::move(ifrt_array_shard)); + } + return ifrt_client->AssembleArrayFromSingleDeviceArrays( + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_array_shards), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); +} + +template +absl::StatusOr HandlePythonScalar(nb::handle obj, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception &e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. 
This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + Shape shape; + PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + + return [data, ifrt_dtype]() -> absl::StatusOr { + const void *ptr = std::visit( + [](const auto &v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandlePythonInt(nb::handle obj, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception &e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception &e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S64; + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, ifrt_dtype]() -> absl::StatusOr { + const void *ptr = std::visit( + [](const auto &v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +template +absl::StatusOr HandleNumpyScalar(nb::handle h, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + std::variant data; + PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. 
+ if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, py_buffer_ref = std::move(py_buffer_ref), + ifrt_dtype]() mutable -> absl::StatusOr { + const void *ptr = std::visit( + [](const auto &v) -> const void * { + if constexpr (std::is_same_v, void *>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = + std::move(py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client *client, ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions &options) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + const void *data = cords.data(); + + // Make an explicit copy of the shape elements so we won't run into complex + // endianness and precision issues that might arise if we reinterpret-casted + // from npy_intp, that can be just 32 bits-wide in some environments + // such as macos_arm64 to const int64_t* that must be 64 bits-wide. 
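+  // In other words, the loop below widens each npy_intp dimension into the
+  // int64_t-based ifrt::Shape::Dimensions one element at a time, rather than
+  // reinterpreting the shape pointer, which would only be valid when
+  // sizeof(npy_intp) == sizeof(int64_t).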
+ ifrt::Shape::Dimensions dims; + dims.reserve(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims.push_back(array.shape(i)); + } + ifrt::Shape shape(std::move(dims)); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [data, shape = std::move(shape), + on_done_with_host_buffer = std::move( + on_done_with_host_buffer)]() mutable -> absl::StatusOr { + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(on_done_with_host_buffer)}; + return Shard( + std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes); + }; +} + +absl::StatusOr HandleNumpyArray(nb::handle h, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. + if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, to_memory_kind, + options); + } + + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); + + PrimitiveType squashed_type; + if (options.squash_64bit_types) { + squashed_type = Squash64BitTypes(type); + if (squashed_type != type) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( + reinterpret_cast(array.ptr()), + reinterpret_cast(squashed_dtype.release().ptr()), + /*fortran=*/0)); + } + } else { + squashed_type = type; + } + + absl::InlinedVector dims(array.ndim()); + ifrt::Client::HostBuffer::ByteStrides byte_strides(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + byte_strides[i] = array.strides(i); + } + const void *data = array.data(); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReference(std::move(array)); + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(squashed_type)); + return [data, ifrt_dtype, dims = std::move(dims), + byte_strides = std::move(byte_strides), + py_buffer_ref = std::move(py_buffer_ref), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + ifrt::Client::HostBufferSemantics host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; + std::function on_done_with_host_buffer; + if (allow_zero_copy) { + on_done_with_host_buffer = + [py_buffer_ref{ + std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + } + + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt_dtype, ifrt::Shape(dims), std::move(byte_strides), + std::move(on_done_with_host_buffer)}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + host_buffer_semantics); + }; +} + +absl::StatusOr HandlePyArray(nb::handle obj, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + auto py_array = nb::borrow(obj); + + // We only allow single device case for PyArray in device put. 
+ if (py_array.num_shards() != 1) { + return InvalidArgument( + "device_put expects an array with exactly one shard, got an array with " + "with %d shards.", + py_array.num_shards()); + } + + ifrt::Array *ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return InvalidArgument("Array has been deleted."); + } + + // Fallback to python for non-matching clients or pmap sharding. + if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || + ifrt_array->sharding().devices()->devices().front()->client() != + to_device->client()) { + return HandleNumpyArray(obj.attr("_value"), client, to_device, + to_memory_kind, options); + } + + if (ifrt_array->sharding().devices()->devices().front() == to_device && + options.allow_zero_copy && + (!to_memory_kind.memory_kind().has_value() || + !ifrt_array->sharding().memory_kind().memory_kind().has_value() || + ifrt_array->sharding().memory_kind() == to_memory_kind)) { + Shard result(tsl::FormRef(ifrt_array), py_array.weak_type()); + return [result = std::move(result)]() mutable { return std::move(result); }; + } else { + return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, + weak_type = py_array.weak_type(), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + auto *ifrt_client = ifrt_array->client(); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&ifrt_array, 1), + ifrt_client->MakeDeviceList({to_device}), to_memory_kind, + allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput + : ifrt::ArrayCopySemantics::kAlwaysCopy)); + return Shard(std::move(copied_ifrt_arrays.front()), weak_type); + }; + } +} + +ifrt::DType Shard::ifrt_dtype() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->dtype(); + } else { + return std::get(ifrt_array_or_host_buffer).dtype; + } +} + +const ifrt::Shape &Shard::ifrt_shape() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->shape(); + } else { + return std::get(ifrt_array_or_host_buffer).shape; + } +} + +// Creates a `ShardFn` that copies `arg` to `to_device` and `to_memory_kind`. +// +// Requires GIL. The returned `ShardFn` should be called without GIL held. +absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client *client, + ifrt::Device *to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions &options) { + using PyObjectDeviceHandlerMap = + absl::flat_hash_map; + + auto init_fn = []() { + std::unique_ptr p = + std::make_unique(); + + const NumpyScalarTypes &dtypes = GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). 
+ (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + if (dtypes.np_int2.has_value()) { + (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + if (dtypes.np_uint2.has_value()) { + (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + return p; + }; + const PyObjectDeviceHandlerMap &handlers = + xla::SafeStaticInit(init_fn); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, to_memory_kind, options); + } + + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { + return res->second(arg, client, to_device, to_memory_kind, options); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays scalars of supported types " + "(see implementation), or Python scalars. 
Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, to_memory_kind, options); +} + +} // namespace + +bool IsFloat0(xla::nb_numpy_ndarray arg) { + static const auto *dtypes_module = + new nb::module_(nb::module_::import_("jax.dtypes")); + static const auto *float0_dtype = + new nb::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + static const absl::flat_hash_map + *const handlers = [] { + auto p = new absl::flat_hash_map(); + + const NumpyScalarTypes &dtypes = GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer overflow. + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitTypes(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if it + // is the same size (long vs long long). 
+ static_assert(sizeof(int64_t) == sizeof(ssize_t), + "Code assumes ssize_t is the same as int64_t"); + return PyArgSignature( + dtype, + absl::MakeConstSpan( + reinterpret_cast(numpy_array.shape()), + numpy_array.ndim()), + /*weak_type=*/false); + }; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; + + ToPyArgSignatureHandler np_uint64_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler np_int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler numpy_array_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + TF_ASSIGN_OR_RETURN(auto dtype, + DtypeToPrimitiveType(h.attr("dtype"))); + return PyArgSignature(dtype, {}, /*weak_type=*/false); + }; + + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float64.ptr()] = float_handler; + (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler; + (*p)[dtypes.np_complex128.ptr()] = complex_handler; + (*p)[dtypes.np_longlong.ptr()] = np_int_handler; + (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; + + return p; + }(); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array *ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return 
xla::InvalidArgument("Array has been deleted."); + } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + } + + auto res = handlers->find(arg.type().ptr()); + if (res == handlers->end()) { + // We attempt to look at the MRO classes + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers->find(base_class.ptr()); + if (res != handlers->end()) { + return res->second(arg, jax_enable_x64); + } + } + return InvalidArgument( + "%s", + absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts " + "Buffer/DeviceArray, Numpy " + "arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, jax_enable_x64); +} + +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client *ifrt_client, + ifrt::Device *ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions &options) { + tsl::profiler::TraceMe traceme("DevicePut"); + ++GetDevicePutInfo().device_put_with_device; + + if (!ifrt_device->IsAddressable()) { + return InvalidArgument("Cannot copy array to non-addressable device: %s", + ifrt_device->DebugString()); + } + + TF_ASSIGN_OR_RETURN(ShardFn shard_fn, + MakeShardFn(addressable_shard, ifrt_client, ifrt_device, + ifrt_memory_kind, options)); + + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fn)()); + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_device, ifrt_memory_kind, shard, + std::move(ifrt_user_context))); + return DevicePutResult(std::move(ifrt_array), shard.weak_type); +} + +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client *ifrt_client, const nb_dtype &dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions &options) { + tsl::profiler::TraceMe traceme("DevicePutWithSharding"); + ++GetDevicePutInfo().device_put_with_sharding; + + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef ifrt_device_list, + GetIfrtDeviceList(sharding)); + ifrt::DeviceList *ifrt_addressable_device_list = + ifrt_device_list->AddressableDeviceList(); + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + // Pmap sharding requires special handling because it needs a shard shape + // upfront. + const bool is_pmap_sharding = sharding.type().is(jax::PmapSharding::type()); + + if (addressable_shards.size() != ifrt_addressable_devices.size()) { + // Try to generate a friendly error message if the user attempted to copy to + // a non-addressable device. + if (addressable_shards.size() > ifrt_addressable_devices.size()) { + for (ifrt::Device *device : ifrt_device_list->devices()) { + if (!device->IsAddressable()) { + return InvalidArgument( + "Cannot copy array to non-addressable device: %s", + device->DebugString()); + } + } + } + // Otherwise, generate a generic error message. + return InvalidArgument( + "Number of addressable shard data does not match the number " + "of addressable devices in the sharding: %d vs. 
%d", + addressable_shards.size(), ifrt_addressable_devices.size()); + } + if (is_pmap_sharding && addressable_shards.empty()) { + return InvalidArgument( + "Pmap sharding requires at least one addressable shard."); + } + + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, DtypeToIfRtDType(dtype)); + ifrt::Shape ifrt_shape(shape); + ifrt::MemoryKind ifrt_memory_kind = GetMemoryKind(sharding); + + std::vector shard_fns; + shard_fns.reserve(addressable_shards.size()); + for (int i = 0; i < addressable_shards.size(); ++i) { + TF_ASSIGN_OR_RETURN( + ShardFn shard, + MakeShardFn(addressable_shards[i], ifrt_client, + ifrt_addressable_devices[i], ifrt_memory_kind, options)); + shard_fns.push_back(std::move(shard)); + } + + ifrt::ShardingRef ifrt_sharding; + bool is_fully_replicated; + if (is_pmap_sharding) { + CHECK(!shard_fns.empty()); + // IFRT Sharding will be determined once we discover the shard shape. + is_fully_replicated = false; + } else { + TF_ASSIGN_OR_RETURN(ifrt_sharding, + GetIfrtHloSharding(sharding, ifrt_shape)); + // Fully-replicated shardings enable additional optimizations of using a + // single host buffer. + // TODO(hyeontaek): Enable a similar optimization for partially replicated + // cases to reduce the number of host buffers to obtain. + is_fully_replicated = ifrt_sharding->IsFullyReplicated(); + } + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + // Whether to build an IFRT array from host buffers as a single batch. We do + // not batch any shard is already an IFRT array. + bool should_batch = true; + + std::vector shards; + shards.reserve(shard_fns.size()); + for (int64_t i = 0; i < shard_fns.size(); ++i) { + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fns[i])()); + if (shard.is_ifrt_array()) { + // If any shard is an IFRT array, we should assemble shards. + should_batch = false; + } + shards.push_back(std::move(shard)); + if (should_batch && is_fully_replicated) { + // We need only one host buffer for a fully-replicated array. + break; + } + } + // While we have finished calling `shard_fns`, we cannot destroy them until we + // make a call to IFRT array creation. Destroying `shard_fns` would release + // host buffers prematurely and can cause the array creation API to see + // garbage data. + + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. 
+ if (!shards.empty()) { + ifrt_dtype = shards.front().ifrt_dtype(); + } + if (is_pmap_sharding) { + ifrt_sharding = ifrt::ConcreteEvenSharding::Create( + ifrt::DeviceListRef(tsl::FormRef(ifrt_addressable_device_list)), + ifrt_memory_kind, ifrt_shape, + /*shard_shape=*/shards.front().ifrt_shape(), + /*is_fully_replicated=*/false); + } + + ifrt::ArrayRef ifrt_array; + if (should_batch) { + if (is_fully_replicated && shards.size() == 1) { + ++GetDevicePutInfo().device_put_fully_replicated; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shards.front(), + std::move(ifrt_user_context))); + } else { + ++GetDevicePutInfo().device_put_batched; + TF_ASSIGN_OR_RETURN(ifrt_array, + MakeIfrtArrayFromShardsInBatch( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } + } else { + ++GetDevicePutInfo().device_put_assembled; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromShardsWithAssembly( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), ifrt_addressable_device_list, + ifrt_memory_kind, absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } + const bool weak_type = shards.empty() ? false : shards.front().weak_type; + return DevicePutResult(std::move(ifrt_array), weak_type); +} + +std::unordered_map DevicePutInfo::GetInfo() { + const DevicePutInfo &info = GetDevicePutInfo(); + return std::unordered_map({ + {"device_put_with_device", info.device_put_with_device}, + {"device_put_with_sharding", info.device_put_with_sharding}, + {"device_put_fully_replicated", info.device_put_fully_replicated}, + {"device_put_batched", info.device_put_batched}, + {"device_put_assembled", info.device_put_assembled}, + }); +} + +} // namespace xla diff --git a/tests/ci_clangformat/py_values.h b/tests/ci_clangformat/py_values.h new file mode 100644 index 0000000..05ef2ab --- /dev/null +++ b/tests/ci_clangformat/py_values.h @@ -0,0 +1,161 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for converting Python values into buffers. + +#ifndef JAXLIB_PY_VALUES_H_ +#define JAXLIB_PY_VALUES_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +struct DevicePutResult { + DevicePutResult(ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array(std::move(ifrt_array)), weak_type(weak_type) {} + + // Disallow copy. `DevicePutResult` is expected to be consumed by one user. 
+ DevicePutResult(const DevicePutResult &) = delete; + DevicePutResult &operator=(const DevicePutResult &) = delete; + DevicePutResult(DevicePutResult &&) noexcept = default; + DevicePutResult &operator=(DevicePutResult &&) noexcept = default; + + // Points to the on-device array. + ifrt::ArrayRef ifrt_array; + bool weak_type; +}; + +// Options for `DevicePut`. +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; +}; + +// Copies a buffer-like object to be on device. This version is designed for +// creating a single-device array. +// +// If `addressable_shard` is not convertible to a `PjRtBuffer` from C++, an +// error will be returned; float0s are not supported yet. +// +// If the value is known to be a PyBuffer object, py_buffer can be passed as an +// optimization to avoid a Python->C++ cast. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// May throw exceptions from nanobind in addition to failing via an error +// absl::Status. (We could catch these if needed, but there seems little point.) +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client *ifrt_client, + ifrt::Device *ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions &options); + +// Copies a buffer-like object to be on device. This version is optimized for +// creating a multi-device array. +// +// `addressable_shards` is a list of buffer-like objects to be copied to +// addressable devices specified in `sharding`. +// +// `shape` and `sharding` determine the shape and sharding of the returned IFRT +// Array. +// +// The size of `addressable_shards` must match the number of addressable devices +// in `sharding`. For a Pmap sharding, there must be at least one addressable +// device. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// See the above `DevicePutWithDevice` for other details. +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client *ifrt_client, const nb_dtype &dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions &options); + +// Returns `true` if `arg` is a JAX float0 array. +bool IsFloat0(xla::nb_numpy_ndarray arg); + +// Describes the abstract shape and dtype of an argument. +struct PyArgSignature { + PyArgSignature(PrimitiveType dtype, absl::Span shape, + bool weak_type) + : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} + // This is the XLA dtype of the object. + const PrimitiveType dtype; + const absl::InlinedVector shape; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + const bool weak_type; + bool operator==(const PyArgSignature &other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const PyArgSignature &other) const { + return !(*this == other); + } + std::string DebugString() const; +}; + +// Returns the PyArgSignature associated with an argument. Returns an error if +// the argument is not supported. 
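+// Illustrative examples (with jax_enable_x64=false): a Python float such as 3.0
+// yields a weak-typed f32[] signature, an np.float32 scalar yields a non-weak
+// f32[] signature, and an np.float64 scalar is likewise squashed to a non-weak
+// f32[] signature.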
+absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); + +template +H AbslHashValue(H h, const xla::PyArgSignature &s) { + h = H::combine(std::move(h), s.dtype); + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); + return h; +} + +// Tracks the number of DevicePut calls and subcases. For testing. +struct DevicePutInfo { + // DevicePutWithDevice call count. + int device_put_with_device = 0; + + // DevicePutWithSharding call count. + int device_put_with_sharding = 0; + + // DevicePutWithSharding with a fully replicated sharding. + int device_put_fully_replicated = 0; + // DevicePutWithSharding that made a batched array creation call. + int device_put_batched = 0; + // DevicePutWithSharding that made per-shard creation calls followed by an + // assembly call. + int device_put_assembled = 0; + + // Returns a map of the counters for the current thread. + static std::unordered_map GetInfo(); +}; + +} // namespace xla + +#endif // JAXLIB_PY_VALUES_H_ diff --git a/tests/ci_clangformat/python_ref_manager.cc b/tests/ci_clangformat/python_ref_manager.cc new file mode 100644 index 0000000..e27ab48 --- /dev/null +++ b/tests/ci_clangformat/python_ref_manager.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/python_ref_manager.h" + +#include + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +namespace nb = nanobind; + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager *manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (nb::object &object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_ && !objects_.empty()) { + manager_->AddGarbage(absl::MakeSpan(objects_)); + } +} + +std::shared_ptr +PythonRefManager::ManageReference(nb::object object) { + return std::make_shared(this, + absl::Span(&object, 1)); +} + +std::shared_ptr +PythonRefManager::ManageReferences(absl::Span objects) { + return std::make_shared(this, objects); +} + +void PythonRefManager::AddGarbage(nb::object garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. 
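+  // The weight of 100 per AddGarbage() call means a single batch of buffer
+  // garbage is already enough to cross the MaybeCollectGarbage() threshold of
+  // 100, whereas the traceback overload below only adds 1 per call, so stack
+  // frames are collected far less eagerly.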
+ garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object &o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + +void PythonRefManager::AddGarbage( + absl::Span const> garbage) { + absl::MutexLock lock(&mu_); + // We don't care about collecting stack frame objects often. We grab a lot of + // tracebacks and the code objects are most likely live for the entire + // process. + garbage_count_.fetch_add(1, std::memory_order_relaxed); + for (const auto &o : garbage) { + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); + } +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): we should CHECK(PyGILState_Check()); + std::deque garbage; + { + absl::MutexLock lock(&mu_); + garbage_count_ = 0; + garbage.swap(python_garbage_); + } + // We defer deleting garbage until the lock is released. It's possible that + // deleting garbage will lead to more Python garbage being added; if we held + // the lock we would deadlock because absl::Mutex is not reentrant. +} + +PythonRefManager *GlobalPyRefManager() { + static PythonRefManager *static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + +} // namespace xla diff --git a/tests/ci_clangformat/python_ref_manager.h b/tests/ci_clangformat/python_ref_manager.h new file mode 100644 index 0000000..d63ba33 --- /dev/null +++ b/tests/ci_clangformat/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PYTHON_REF_MANAGER_H_ +#define JAXLIB_PYTHON_REF_MANAGER_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of nanobind::objects, adding the references to + // the PythonRefManager on destruction. 
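+  // Illustrative usage (hypothetical names): capture the shared_ptr returned
+  // by ManageReference() in an asynchronous transfer's completion callback:
+  //   auto ref = GlobalPyRefManager()->ManageReference(std::move(py_obj));
+  //   auto on_done = [ref]() {};  // dropping `ref` queues `py_obj` as garbage
+  // The queued object is only destroyed by a later CollectGarbage() call made
+  // while the GIL is held.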
+ class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager *manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects &other) = delete; + ManagedPyObjects(ManagedPyObjects &&other) = default; + ManagedPyObjects &operator=(const ManagedPyObjects &other) = delete; + ManagedPyObjects &operator=(ManagedPyObjects &&other) noexcept = default; + + private: + PythonRefManager *manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + std::shared_ptr ManageReference(nanobind::object object); + std::shared_ptr ManageReferences( + absl::Span objects); + + // Adds garbage objects to the manager. + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); + void AddGarbage(absl::Span const> garbage); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + // Cheaper version of CollectGarbage() with relaxed consistency and frequency. + // The purpose of this function is to amortize lock acquisition costs over + // a larger number of API calls. + void MaybeCollectGarbage() { + if (garbage_count_.load(std::memory_order_relaxed) >= 100) { + CollectGarbage(); + } + } + + private: + absl::Mutex mu_; + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + + // Writes to garbage_count_ are protected by mu_, reads are not protected. + std::atomic garbage_count_{0}; +}; + +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager *GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_PYTHON_REF_MANAGER_H_ diff --git a/tests/ci_clangformat/pytree.cc b/tests/ci_clangformat/pytree.cc new file mode 100644 index 0000000..a3af748 --- /dev/null +++ b/tests/ci_clangformat/pytree.cc @@ -0,0 +1,1831 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. 
+ +#include "jaxlib/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/tsl/platform/logging.h" + +namespace xla { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject *type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + 
leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void *arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto &field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto &field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. +PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const **custom) const { + const PyTreeRegistry::Registration *registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. 
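+    // For example, instances of collections.namedtuple types and
+    // typing.NamedTuple subclasses are tuples carrying a _fields attribute, so
+    // they land here unless their type was explicitly registered above.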
+ return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration *PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject *py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject *key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object &a, const nb::object &b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef &other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node &a = traversal_[i]; + const Node &b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. 
+ } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const *custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object &key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. 
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void *)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void *)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeRegistry *registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + for (const auto &[key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject *self) { + PyTreeRegistry *registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void *)DictKey::tp_traverse}, + {Py_tp_clear, (void *)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + DictKey *key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject *self) { + DictKey *dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", 
nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object &other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object &other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object &other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object &other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +template +void PyTreeDef::FlattenImpl(nb::handle handle, T &leaves, + const std::optional &leaf_predicate, + std::optional> &keypath) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o = (*leaf_predicate)(handle); + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector &frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional> &keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, leaf_predicate, keypath); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. 
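+        // None is recorded as an arity-0 interior node: it adds an entry to
+        // `traversal_` but contributes no leaves.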
+ break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object &key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto &[key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. 
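+          // A GetAttrKey for the field name is pushed onto the key path
+          // before recursing into each entry and popped afterwards.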
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector &frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector &leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector &leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list &leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list &leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry *registry, + const nb::iterable &x) { + const PyTreeRegistry::Registration *custom; + for (const nb::handle &h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node &node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - 
node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node &node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node &node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. 
To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto *registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw 
std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto *registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclass node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable &f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node &node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? 
node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator *it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node &node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef &inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node &n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef *def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const &root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node &node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node &node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + 
std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle &key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto &node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto &item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node &node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? 
nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(jax::PyTreeDefProto &result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string &key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto &node : traversal_) { + auto *node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto &key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. 
You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto &input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto &s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto &node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto &node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case jax::PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject *type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto &node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); + int num_leaves = 0; + int arity = 0; + for (nb::handle pchild : children) { + const PyTreeDef &child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); + num_leaves += child.num_leaves(); + ++arity; + } + result->traversal_.emplace_back(); + auto &node = result->traversal_.back(); + node.arity = arity; + node.custom = nullptr; + node.num_leaves = num_leaves; + node.num_nodes = result->traversal_.size(); + if (node_data == std::nullopt) { + node.kind = PyTreeKind::kLeaf; + ++node.num_leaves; + return result; + } + int is_nt = PyObject_IsSubclass(node_data->first.ptr(), + reinterpret_cast(&PyTuple_Type)); + if (is_nt == -1) { + throw nb::python_error(); + } + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { + node.kind = PyTreeKind::kNamedTuple; + node.node_data = 
node_data->first; + return result; + } + auto *registration = result->registry()->Lookup(node_data->first); + if (registration == nullptr) { + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); + } + node.kind = registration->kind; + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { + node.custom = registration; + node.node_data = node_data->second; + } else if (node.kind == PyTreeKind::kNamedTuple) { + node.node_data = node_data->first; + } else if (node.kind == PyTreeKind::kDict) { + node.sorted_dict_keys = + nb::cast>(node_data->second); + } + return result; +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void *arg) const { + Py_VISIT(node_data.ptr()); + for (const auto &key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeDef *treedef = nb::inst_ptr(self); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto &node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject *self) { + PyTreeDef *treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void *)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void *)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_ &m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass); + registry.def("__reduce__", + [](nb::object self) { return self.attr("__name__"); }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, 
/*enable_namedtuple=*/true, + /*enable_list=*/true, /*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", + [](const PyTreeDef &a, const PyTreeDef &b) { return a == b; }); + treedef.def("__ne__", + [](const PyTreeDef &a, const PyTreeDef &b) { return a != b; }); + treedef.def("__hash__", [](const PyTreeDef &t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef &a) { + jax::PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + jax::PyTreeDefProto input; + absl::string_view serialized(data.c_str(), data.size()); + if (serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)"); + treedef.def_static( + "make_from_node_data_and_children", + &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), + "Reconstructs a pytree from `node_data()` and `children()`."); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef &t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey"); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey &key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey &key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey &key, const nb::tuple &state) { + if (state.size() != 1) { + throw 
xla::XlaRuntimeError( + "Malformed pickled SequenceKey, expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_)); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey &key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey &key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey &key, const nb::tuple &state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey"); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey &key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey &key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey &key, const nb::tuple &state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key(pytree, + "FlattenedIndexKey"); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey &key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey &key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey &key, const nb::tuple &state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace xla diff --git a/tests/ci_clangformat/pytree.h b/tests/ci_clangformat/pytree.h new file mode 100644 index 0000000..802b73e --- /dev/null +++ b/tests/ci_clangformat/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PYTREE_H_ +#define JAXLIB_PYTREE_H_ + +// See https://docs.jax.dev/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" +#include "nanobind/nanobind.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry &) = delete; + PyTreeRegistry(PyTreeRegistry &&) = delete; + PyTreeRegistry &operator=(const PyTreeRegistry &) = delete; + PyTreeRegistry &operator=(PyTreeRegistry &&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. + nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void *arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration *Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const **custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. 
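+  // Shared by both entry points; `with_keys` selects whether (key, leaf)
+  // pairs or bare leaves are produced.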
+ nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object &t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle &t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object &a, + const nanobind::object &b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object &a, + const nanobind::handle &b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object &other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object &other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object &other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object &other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry *registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. 
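+  // The overloads below differ only in the container used to collect the
+  // leaves; leaves are appended in the order they are visited.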
+ void Flatten(nanobind::handle handle, std::vector &leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector &leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list &leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list &leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry *registry, const nanobind::iterable &x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef &inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable &f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry *registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef &other) const; + bool operator!=(const PyTreeDef &other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(jax::PyTreeDefProto &result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto &input); + + std::optional> GetNodeData() + const; + + static nb_class_ptr MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nanobind::iterable children); + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. 
For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration *custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void *arg) const; + }; + template + friend H AbslHashValue(H h, const Node &n); + + template + friend H AbslHashValue(H h, const PyTreeDef &t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node &node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator *it) + const; + + template + void FlattenImpl(nanobind::handle handle, T &leaves, + const std::optional &leaf_predicate, + std::optional> &keypath); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); + + // Pytree registry. Not owned. + PyTreeRegistry *registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node &n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef &t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_PYTREE_H_ diff --git a/tests/ci_clangformat/sdy.cc b/tests/ci_clangformat/sdy.cc new file mode 100644 index 0000000..c179919 --- /dev/null +++ b/tests/ci_clangformat/sdy.cc @@ -0,0 +1,140 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/sdy.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mhlo/transforms/passes.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" +#include "llvm/Support/raw_ostream.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +} // namespace + +void BuildSdySubmodule(nb::module_ &m) { + nb::module_ mlir_module = m.def_submodule("sdy", "Shardy/XLA integration"); + + mlir_module + // TODO(b/707574930): define a C API for the XLA pipelines. 
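+      // Each binding below parses the incoming module from MLIR bytecode; the
+      // round-trip passes re-serialize the transformed module to bytecode,
+      // while lowered_with_shardy and get_mesh only inspect its attributes.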
+ .def( + "sdy_round_trip_export_pipeline", + [](const nb::bytes &bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + sdy::addSdyRoundTripExportPipeline(pm); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def( + "sdy_round_trip_import_shardings", + [](const nb::bytes &bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + pm.addPass(xla::sdy::createSdyRoundTripImportShardyAttrsPass()); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def("lowered_with_shardy", + [](const nb::bytes &bytecode) -> bool { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + return mlir::sdy::getMeshAttr(module.get(), "mesh") || + sdy::tryGetFrontendAttr( + module.get(), sdy::kMeshesRoundTripAttr) + .has_value(); + }) + // TODO(bartchr): delete this and all uses of it once I have JAX export + // support multiple meshes. + .def("get_mesh", [](const nb::bytes &bytecode) -> nb::list { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), context)); + auto mesh_attr = mlir::sdy::getMeshAttr(module.get(), "mesh"); + if (!mesh_attr) { + return {}; + } + nb::list mesh_shape; + for (auto axis : mesh_attr.getAxes()) { + mesh_shape.append( + nb::make_tuple(axis.getName().str(), axis.getSize())); + } + return mesh_shape; + }); +} + +} // namespace xla diff --git a/tests/ci_clangformat/sdy.h b/tests/ci_clangformat/sdy.h new file mode 100644 index 0000000..347b172 --- /dev/null +++ b/tests/ci_clangformat/sdy.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_SDY_H_ +#define JAXLIB_SDY_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildSdySubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_SDY_H_ diff --git a/tests/ci_clangformat/sharded_device_array.h b/tests/ci_clangformat/sharded_device_array.h new file mode 100644 index 0000000..cef80d1 --- /dev/null +++ b/tests/ci_clangformat/sharded_device_array.h @@ -0,0 +1,216 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_SHARDED_DEVICE_ARRAY_H_ + +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/python/types.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +// High level introduction. +// +// pmap and other parallel computation functions distribute some computation on +// several devices. On December 2020, the devices mesh (i.e. N-dimensional array +// of devices on which we map the computation) is defined by the user. +// +// We describe how to shard the inputs, and how to map it to the mesh of devices +// using `ShardingSpec`. It's mainly based on 2 components: +// - `sharding`, which specifies how to shard the inputs. +// - `mesh_mapping`, which specifies how to map shards to devices. +// +// The 3 following structs define how to shard one dimension of an ndarry. +// +// `NoSharding` (`None` in Python) means no sharding. +struct NoSharding { + bool operator==(const NoSharding &other) const { return true; } + bool operator!=(const NoSharding &other) const { return false; } +}; + +template +H AbslHashValue(H h, const NoSharding &key) { + return h; +} + +// `Chunked` means that the dimension is split into np.prod(chunks) chunks +// and the split dimension itself is preserved inside the map. +// Those chunks are distributed over `len(chunks)` ShardedAxes axes +// (major-to-minor). +// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with +// p dividing N, let S = N // p) the tensor will be split into p chunks of +// shape [S], such sharded_t[k] = t[k * S: (k+1)*S] (left included, right +// excluded) for k in {0, ... p-1}. +struct Chunked { + public: + explicit Chunked(std::vector chunks_) : chunks(std::move(chunks_)) {} + // The number of chunks per axis. 
+ std::vector chunks; + + bool operator==(const Chunked &other) const { return chunks == other.chunks; } + bool operator!=(const Chunked &other) const { return chunks != other.chunks; } +}; + +template +H AbslHashValue(H h, const Chunked &key) { + h = H::combine(std::move(h), key.chunks); + return h; +} + +// `Unstacked` means that the dimension is split into chunks of size 1, and +// doesn't appear inside the map. `size` is always the dimension size. +// For example, a Tensor t of shape [N] will be sharded into N tensors of shape +// [], when using `Unstacked(N)`. +struct Unstacked { + public: + explicit Unstacked(int sz) : size(sz) {} + int size; + + bool operator==(const Unstacked &other) const { return size == other.size; } + bool operator!=(const Unstacked &other) const { return size != other.size; } +}; + +template +H AbslHashValue(H h, const Unstacked &key) { + h = H::combine(std::move(h), key.size); + return h; +} + +using AvalDimSharding = std::variant; + +// Assigns sharded axes to mesh dimensions. +// +// The devices will be for each dimension which has a sharded `AvalDimSharding` +// When no axis is assigned, the data is replicated. +// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually +// sharded axis (i.e. counting as if the None dimensions of sharding were +// filtered out). +// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry +// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`. + +struct ShardedAxis { + int axis; + bool operator==(const ShardedAxis &other) const { return axis == other.axis; } + bool operator!=(const ShardedAxis &other) const { return axis != other.axis; } +}; + +template +H AbslHashValue(H h, const ShardedAxis &key) { + h = H::combine(std::move(h), key.axis); + return h; +} + +struct Replicated { + int replicas; + bool operator==(const Replicated &other) const { + return replicas == other.replicas; + } + bool operator!=(const Replicated &other) const { + return replicas != other.replicas; + } +}; + +template +H AbslHashValue(H h, const Replicated &key) { + h = H::combine(std::move(h), key.replicas); + return h; +} + +using MeshDimAssignment = std::variant; + +// Describes how each axis is sharded (if it is), and how it's mapped to the +// devices mesh. See Jax pxla.py for the documentation. +// +// ShardingSpec is shared across pmap, pjit and xpmap. For pmap, an input +// `sharding` is composed of `NoSharding` and at most one `Unstacked`. +// If `axis_size=None`, at least one the inputs has a dimension associated to +// `Unstacked`. +// +// Examples: +// +// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first +// dimension into [8] devices: +// +// sharding = [Unstacked(8), NoSharding, NoSharding] +// mesh_mapping = [ShardedAxis(0)] +// +// 2. With an input array of shape [6], that we want to chunk into [2, 3] +// Assuming a device mesh [3, 4, 2] of devices, we will have: +// +// sharding = [Chunked([2, 3])] +// mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)] +// +// In particular, in the above example, the ShardedAxis refers to indices +// of the sharded shape [2, 3]. (only the `Chunked` sharding can produce more +// than one dimension). 
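+//
+// Editorial sketch (added for illustration; not part of the original jaxlib
+// sources, and the exact constructor signatures are assumed from the
+// declarations below): example 2 above could be written out against these
+// types roughly as
+//
+//   std::vector<AvalDimSharding> sharding = {Chunked({2, 3})};
+//   std::vector<MeshDimAssignment> mesh_mapping = {
+//       ShardedAxis{1}, Replicated{4}, ShardedAxis{0}};
+//   ShardingSpec spec(std::move(sharding), std::move(mesh_mapping));
+//
+// where Replicated{4} is assumed to mark the mesh axis of size 4 as holding
+// 4 replicas of the data rather than a sharded dimension.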
+class ShardingSpec { + public: + ShardingSpec(std::vector sharding, + std::vector mesh_mapping) + : sharding_(std::move(sharding)), + mesh_mapping_(std::move(mesh_mapping)) {} + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) + : sharding_(xla::IterableToVector(py_sharding)), + mesh_mapping_( + xla::IterableToVector(py_mesh_mapping)) {} + + const std::vector &GetSharding() const { return sharding_; } + const std::vector &GetMeshMapping() const { + return mesh_mapping_; + } + + bool operator==(const ShardingSpec &other) const { + return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_; + } + + bool operator!=(const ShardingSpec &other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ShardingSpec &key); + + private: + // `sharding` specifies how the array is supposed to get partitioned into + // chunks. Its length matches the rank of the array. See the docstring + // of `AvalDimSharding` for the supported partitioning schemes. + std::vector sharding_; + // `mesh_mapping` describes an assignments of the array chunks created by + // `sharding` to a logical device mesh. The length of the tuple is equal to + // the rank of the mesh. Each mesh dimension can either get partitions of + // data varying along one of the sharded dimensions, or the data can be + // replicated. + std::vector mesh_mapping_; +}; + +template +H AbslHashValue(H h, const ShardingSpec &key) { + h = H::combine(std::move(h), key.sharding_); + h = H::combine(std::move(h), key.mesh_mapping_); + return h; +} + +} // namespace jax + +#endif // JAXLIB_SHARDED_DEVICE_ARRAY_H_ diff --git a/tests/ci_clangformat/sharding.cc b/tests/ci_clangformat/sharding.cc new file mode 100644 index 0000000..671f794 --- /dev/null +++ b/tests/ci_clangformat/sharding.cc @@ -0,0 +1,396 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +// Gets `jax::PyDeviceList` from a JAX Sharding. 
+absl::StatusOr> GetPyDeviceList( + nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list; + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else { + return nb::cast>( + sharding.attr("_internal_device_list")); + } +} + +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, + const xla::nb_class_ptr &device_list) { + if (!memory_kind.is_none()) { + // If memory kind is not None, check if it's supported by the devices + // mentioned in the Sharding. + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); + if (!supported_memory_kinds.ok()) { + supported_memory_kinds = nb::tuple(); + } + for (nb::handle supported_memory_kind : *supported_memory_kinds) { + if (supported_memory_kind.equal(memory_kind)) { + return memory_kind; + } + } + auto addressable_device_list = + PyDeviceList::AddressableDeviceList(device_list); + if (addressable_device_list->Len() == 0) { + // If the device list is not addressable, we can't check if the memory + // kind is supported, so we assume it is. + return memory_kind; + } + nb::object device_kind = + addressable_device_list->GetItem(0).attr("device_kind"); + absl::string_view device_kind_str = + nb::cast(device_kind); + auto py_str_formatter = [](std::string *out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); + } + // If memory kind is None, canonicalize to default memory. 
+ absl::StatusOr default_memory_kind = + PyDeviceList::DefaultMemoryKind(device_list); + if (!default_memory_kind.ok()) { + return nb::none(); + } + return *std::move(default_memory_kind); +} + +int Sharding::SafeNumDevices(nb::handle sharding) { + const jax::Sharding *cpp_sharding; + if (nb::try_cast(sharding, cpp_sharding)) { + if (cpp_sharding->num_devices_.has_value()) { + return (*cpp_sharding->num_devices_); + } + } + nb::set device_set = sharding.attr("device_set"); + return device_set.size(); +} + +size_t ShardingHash(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto *named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto *gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto *single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + return nb::hash(sharding); +} + +bool ShardingEqual(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()) return true; + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) return false; + + if (a_type.is(NamedSharding::type())) { + auto *a_named_sharding = nb::inst_ptr(a); + auto *b_named_sharding = nb::inst_ptr(b); + + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + a_named_sharding->spec().equal(b_named_sharding->spec()) && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto *a_gspmd_sharding = nb::inst_ptr(a); + auto *b_gspmd_sharding = nb::inst_ptr(b); + + return a_gspmd_sharding == b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto *a_single_device_sharding = + nb::inst_ptr(a); + auto *b_single_device_sharding = + nb::inst_ptr(b); + + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + +NamedSharding::NamedSharding(nb::object mesh, nb::object spec, + nb::object memory_kind, + nb::object logical_device_ids) + : Sharding(/*num_devices=*/[&mesh]() { + return nb::cast(mesh.attr("size")); + }()), + mesh_(std::move(mesh)), + spec_(std::move(spec)), + memory_kind_(std::move(memory_kind)), + logical_device_ids_(std::move(logical_device_ids)) { + if (spec_.is_none()) { + throw nb::type_error( + "Unexpected None passed as spec for NamedSharding. Did you mean P()?"); + } + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>(idl); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". 
Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } + } + + // TODO(phawkins): this leaks a reference to the check_pspec function. + // A better way to fix this would be to move PartitionSpec and this check into + // C++. + auto init_fn = []() { + nb::module_ si = nb::module_::import_("jax._src.named_sharding"); + return std::make_unique(si.attr("check_pspec")); + }; + nb::object &check_pspec = xla::SafeStaticInit(init_fn); + check_pspec(mesh_, spec_); +} + +/*static*/ PyObject *NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +SingleDeviceSharding::SingleDeviceSharding(nb::object device, + nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(device), + memory_kind_(std::move(memory_kind)), + internal_device_list_( + xla::make_nb_class(nb::make_tuple(std::move(device)))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject *SingleDeviceSharding::type_ = nullptr; + +/*static*/ void SingleDeviceSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +SingleDeviceSharding::SingleDeviceSharding( + xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(client->GetPyDevice(device_list->devices().front())), + memory_kind_(std::move(memory_kind)), + internal_device_list_(xla::make_nb_class( + std::move(client), std::move(device_list))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, + ShardingSpec sharding_spec) + : Sharding(/*num_devices=*/devices.size()), + devices_(std::move(devices)), + sharding_spec_(std::move(sharding_spec)) { + nb::object flat_devices = devices_.attr("flat"); + internal_device_list_ = + xla::make_nb_class(nb::tuple(flat_devices)); +} + +/*static*/ PyObject *PmapSharding::type_ = nullptr; + +// /*static*/ nanobind::handle PmapSharding::type() { return type_; } + +/*static*/ void PmapSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, + nb::object memory_kind, nb::object device_list) + : Sharding(/*num_devices=*/nb::len(devices.ptr())), + devices_(nb::tuple(devices)), + hlo_sharding_(std::move(op_sharding)), + memory_kind_(std::move(memory_kind)) { + if (device_list.is_none()) { + internal_device_list_ = xla::make_nb_class(devices_); + } else { + internal_device_list_ = + nb::cast>(std::move(device_list)); + } + // This checks in python if the memory kind is correct for the given + // devices. Currently in python this check is optimized but we want to + // move that check to C++ after which we can remove this call. + CHECK(devices_.size() != 0) + << "Devices given to GSPMDSharding must not be empty"; + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject *GSPMDSharding::type_ = nullptr; + +/*static*/ void GSPMDSharding::InitializeType() { + // Intentionally leaks a reference. 
+ type_ = nanobind::type().inc_ref().ptr(); +} + +void RegisterSharding(nb::module_ &m) { + nb::class_(m, "Sharding").def(nb::init<>()); + + nb::class_(m, "NamedSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("mesh"), nb::arg("spec").none(), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_logical_device_ids").none() = nb::none()) + .def_prop_ro("mesh", &NamedSharding::mesh) + .def_prop_ro("spec", &NamedSharding::spec) + .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) + .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) + .def_prop_ro("_internal_device_list", [](const NamedSharding &s) { + return xla::ValueOrThrow(s.internal_device_list()); + }); + NamedSharding::InitializeType(); + + nb::class_(m, "SingleDeviceSharding", + nb::dynamic_attr()) + .def(nb::init(), nb::arg("device"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_device", &SingleDeviceSharding::device) + .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &SingleDeviceSharding::internal_device_list); + SingleDeviceSharding::InitializeType(); + + nb::class_(m, "PmapSharding", nb::dynamic_attr()) + .def( + "__init__", + [](PmapSharding *self, nb::object devices, + ShardingSpec sharding_spec) { + new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices), + std::move(sharding_spec)); + }, + nb::arg("devices"), nb::arg("sharding_spec")) + .def_prop_ro("devices", &PmapSharding::devices) + .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) + .def_prop_ro("_internal_device_list", + &PmapSharding::internal_device_list); + PmapSharding::InitializeType(); + + nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def_prop_ro("_devices", &GSPMDSharding::devices) + .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding) + .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &GSPMDSharding::internal_device_list); + GSPMDSharding::InitializeType(); +} + +} // namespace jax diff --git a/tests/ci_clangformat/sharding.h b/tests/ci_clangformat/sharding.h new file mode 100644 index 0000000..bf1b2c7 --- /dev/null +++ b/tests/ci_clangformat/sharding.h @@ -0,0 +1,241 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_SHARDING_H_ +#define JAXLIB_SHARDING_H_ + +#include + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + static int SafeNumDevices(nanobind::handle sharding); + + private: + std::optional num_devices_; +}; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr &device_list); + +// Returns a hash that may sometimes return different hashes for equal values. +// It is not a correct implementation of `__hash__` in python, but it's fine +// for jit/pjit dispatch since it only causes spurious cache misses. +size_t ShardingHash(nanobind::handle sharding); + +bool ShardingEqual(nanobind::handle a, nanobind::handle b); + +class NamedSharding : public Sharding { + public: + NamedSharding(nanobind::object mesh, nanobind::object spec, + nanobind::object memory_kind, + nanobind::object logical_device_ids); + + const nanobind::object &mesh() const { return mesh_; } + const nanobind::object &spec() const { return spec_; } + const nanobind::object &memory_kind() const { return memory_kind_; } + const nanobind::object &logical_device_ids() const { + return logical_device_ids_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); + } + + private: + nanobind::object mesh_; + nanobind::object spec_; + nanobind::object memory_kind_; + nanobind::object logical_device_ids_; + std::optional> internal_device_list_; + static PyObject *type_; +}; + +class SingleDeviceSharding : public Sharding { + public: + explicit SingleDeviceSharding( + nanobind::object device, nanobind::object memory_kind = nanobind::none()); + + // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. 
+ SingleDeviceSharding(xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nanobind::object memory_kind); + + const nanobind::object &device() const { return device_; } + const nanobind::object &memory_kind() const { return memory_kind_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + nanobind::object device_; + nanobind::object memory_kind_; + xla::nb_class_ptr internal_device_list_; + + static PyObject *type_; +}; + +// The C++ implementation of jax.PmapSharding in python. It contains a few key +// data members and methods that are performance-critical. +class PmapSharding : public Sharding { + public: + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); + + ~PmapSharding() override = default; + + xla::nb_numpy_ndarray devices() const { return devices_; } + + const ShardingSpec &sharding_spec() const { return sharding_spec_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + xla::nb_numpy_ndarray devices_; + ShardingSpec sharding_spec_; + xla::nb_class_ptr internal_device_list_; + static PyObject *type_; +}; + +class GSPMDSharding : public Sharding { + public: + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list) + : GSPMDSharding( + std::move(devices), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind), std::move(device_list)) {} + + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list); + + const nanobind::tuple &devices() const { return devices_; } + const nanobind::object &memory_kind() const { return memory_kind_; } + + size_t Hash() { + if (!hash_.has_value()) { + hash_ = CalculateHash(); + } + return *hash_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + const xla::HloSharding &hlo_sharding() const { return hlo_sharding_; } + + bool operator==(const GSPMDSharding &other) const { + return AreOpShardingsEqual(*this, other) && + this->devices().equal(other.devices()) && + this->memory_kind().equal(other.memory_kind()); + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + size_t CalculateHash() const { + // We only hash `hlo_sharding_` here for performance. + return absl::Hash()(hlo_sharding_); + } + + static bool AreOpShardingsEqual(const GSPMDSharding &a, + const GSPMDSharding &b) { + // If the OpSharding object is the same, return true + if (&a.hlo_sharding() == &b.hlo_sharding()) { + return true; + } + // If both OpShardings are replicated, return true + if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { + return true; + } + return a.hlo_sharding() == b.hlo_sharding(); + } + + bool IsOpShardingReplicated() const { + // For JAX, shardings with 1 device are considered as replicated in its + // semantics so that downstream things continue to work. 
+ if (hlo_sharding_.tile_assignment().num_elements() == 1) { + return true; + } + return hlo_sharding().IsReplicated(); + } + + nanobind::tuple devices_; + xla::HloSharding hlo_sharding_; + nanobind::object memory_kind_; + std::optional hash_; + xla::nb_class_ptr internal_device_list_; + + static PyObject *type_; +}; + +void RegisterSharding(nanobind::module_ &m); + +} // namespace jax + +#endif // JAXLIB_SHARDING_H_ diff --git a/tests/ci_clangformat/to_ifrt_sharding.cc b/tests/ci_clangformat/to_ifrt_sharding.cc new file mode 100644 index 0000000..31e4047 --- /dev/null +++ b/tests/ci_clangformat/to_ifrt_sharding.cc @@ -0,0 +1,141 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/to_ifrt_sharding.h" + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharding.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr())) + ->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nb::handle sharding_py) { + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding_py)); + return py_device_list->ifrt_device_list(); +} + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) { + nb::object py_memory_kind = nb::none(); + + // sharding.attr("memory_kind") can crash if sharding was originally created + // from C++ and casted into a Python Sharding object. Thus, we cast sharding + // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python + // attribute access. 
+ nb::handle type = sharding.type(); + if (type.is(jax::NamedSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::SingleDeviceSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::GSPMDSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else { + py_memory_kind = sharding.attr("memory_kind"); + } + + if (py_memory_kind.is_none()) { + return xla::ifrt::MemoryKind(); + } + return xla::ifrt::MemoryKind(nb::cast(py_memory_kind)); +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape &shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), std::move(memory_kind), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape &shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nb::handle sharding, const xla::ifrt::Shape &shape, + std::vector shard_shapes) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + return xla::ifrt::ConcreteSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shapes=*/std::move(shard_shapes)); +} + +} // namespace xla diff --git a/tests/ci_clangformat/to_ifrt_sharding.h b/tests/ci_clangformat/to_ifrt_sharding.h new file mode 100644 index 0000000..54b7afc --- /dev/null +++ b/tests/ci_clangformat/to_ifrt_sharding.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_TO_IFRT_SHARDING_H_ +#define JAXLIB_TO_IFRT_SHARDING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" + +namespace xla { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape &shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape &shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nanobind::handle sharding, const xla::ifrt::Shape &shape, + std::vector shard_shapes); + +} // namespace xla + +#endif // JAXLIB_TO_IFRT_SHARDING_H_ diff --git a/tests/ci_clangformat/traceback.cc b/tests/ci_clangformat/traceback.cc new file mode 100644 index 0000000..44ad8c8 --- /dev/null +++ b/tests/ci_clangformat/traceback.cc @@ -0,0 +1,357 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/traceback.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "jaxlib/nb_class_ptr.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "tsl/platform/platform.h" +#include "xla/pjrt/exceptions.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace xla { + +namespace nb = nanobind; + +bool Traceback::enabled_ = true; + +Traceback::Traceback() { + DCHECK(PyGILState_Check()); + PyThreadState *thread_state = PyThreadState_GET(); + +#if PY_VERSION_HEX < 0x030b0000 + // The representation of frame->f_lasti changed from bytes to words in Python + // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api + // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. + constexpr int kLastiWordBytes = 2; + + for (PyFrameObject *py_frame = thread_state->frame; py_frame != nullptr; + py_frame = py_frame->f_back) { + Py_INCREF(py_frame->f_code); + frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); + } +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! + for (_PyInterpreterFrame *f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames_.emplace_back(f->f_code, + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + } +#else // PLATFORM_GOOGLE + PyFrameObject *next; + for (PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state); + py_frame != nullptr; py_frame = next) { + frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)); + next = PyFrame_GetBack(py_frame); + Py_XDECREF(py_frame); + } +#endif // PLATFORM_GOOGLE + +#endif // PY_VERSION_HEX < 0x030b0000 +} + +Traceback::~Traceback() { + for (auto &frame : frames_) { + DCHECK(PyGILState_Check()); + Py_DECREF(frame.first); + } +} + +Traceback::Traceback(Traceback &&other) noexcept + : frames_(std::move(other.frames_)) { + // absl::InlinedVector does not always clear itself if moved. Since we rely on + // its empty() method to destroy Traceback differently, we explicitly clear + // here. 
+ other.frames_.clear(); +} + +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + std::vector frame_strs; + frame_strs.reserve(frames_.size()); + for (const Frame &frame : Frames()) { + frame_strs.push_back(frame.ToString()); + } + return absl::StrJoin(frame_strs, "\n"); +} + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + std::vector frames; + frames.reserve(frames_.size()); + for (const auto &frame : frames_) { + frames.push_back(Frame{nb::borrow(frame.first->co_filename), + nb::borrow(frame.first->co_name), + frame.first->co_firstlineno, + PyCode_Addr2Line(frame.first, frame.second)}); + } + return frames; +} + +std::optional> Traceback::Get() { + DCHECK(PyGILState_Check()); + if (!enabled_) { + return std::nullopt; + } + return make_nb_class(); +} + +void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; } + +nb::object Traceback::AsPythonTraceback() const { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + for (const std::pair &frame : frames_) { + int lineno = PyCode_Addr2Line(frame.first, frame.second); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject *py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.first->co_filename), + PyUnicode_AsUTF8(frame.first->co_name), lineno); + PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + PyCode_Addr2Line(frame.first, frame.second)); + } + return traceback; +} + +namespace { + +Py_hash_t traceback_tp_hash(PyObject *o) { + Traceback *tb; + if (!nb::try_cast(nb::handle(o), tb)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return -1; + } + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +PyObject *traceback_tp_richcompare(PyObject *self, PyObject *other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); + } + + Traceback *x; + if (!nb::try_cast(nb::handle(self), x)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return nullptr; + } + + bool result; + Traceback *y; + if (nb::try_cast(nb::handle(other), y)) { + result = ((*x == *y) == (op == Py_EQ)); + } else { + result = (op == Py_NE); + } + return Py_NewRef(result ? Py_True : Py_False); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. 
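+// (Editorial note: CPython reserves -1 as the error return value of a
+// tp_hash slot, which is why traceback_tp_hash above remaps a computed hash
+// of -1 to -2 before returning it.)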
+PyType_Slot traceback_slots_[] = { + {Py_tp_hash, (void *)traceback_tp_hash}, + {Py_tp_richcompare, (void *)traceback_tp_richcompare}, + {0, nullptr}, +}; + +} // namespace + +void BuildTracebackSubmodule(nb::module_ &m) { + nb::class_(m, "Frame") + .def(nb::init()) + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) + .def("__repr__", [](const Traceback::Frame &frame) { + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); + }); + + nb::class_ traceback(m, "Traceback", + nb::type_slots(traceback_slots_), + "Represents a Python stack trace."); + traceback.def_prop_rw_static( + "enabled", [](nb::object /* cls */) { return Traceback::enabled(); }, + [](nb::object /* cls */, bool enabled) { + return Traceback::SetEnabled(enabled); + }); + traceback.def_static( + "get_traceback", []() { return Traceback::Get(); }, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object + that describes the Python stack of the calling thread. Stack trace + collection has a small overhead, so it is disabled by default. If traceback + collection is disabled, returns ``None``. + )doc"); + traceback.def_prop_ro("frames", &Traceback::Frames); + traceback.def("raw_frames", [](const Traceback &tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything rather + // than one per frame. + nb::list out_code = nb::steal(PyList_New(tb.raw_frames().size())); + nb::list out_lasti = + nb::steal(PyList_New(tb.raw_frames().size())); + for (size_t i = 0; i < tb.raw_frames().size(); ++i) { + const auto &frame = tb.raw_frames()[i]; + PyObject *code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }); + traceback.def("__str__", &Traceback::ToString); + traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + + traceback.def_static( + "traceback_from_frames", + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame &frame : frames) { + PyCodeObject *py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + + traceback.def_static( + "code_addr2line", + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line"); + +#if PY_VERSION_HEX >= 0x030b0000 + traceback.def_static( + "code_addr2location", + [](nb::handle code, int lasti) { + if 
(!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location"); +#endif // PY_VERSION_HEX >= 0x030b0000 + +#if PY_VERSION_HEX < 0x030b0000 + // This function replaces the exception traceback associated with the current + // Python thread. + m.def( + "replace_thread_exc_traceback", + [](nb::object tb) { + if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) { + throw xla::XlaRuntimeError( + "argument must be a traceback object or None"); + } + PyThreadState *thread_state = PyThreadState_Get(); + if (!thread_state->exc_info->exc_traceback) { + throw xla::XlaRuntimeError( + "Current thread does not have an active " + "exception traceback"); + } + PyObject *old_exc_traceback = thread_state->exc_info->exc_traceback; + PyObject *new_tb = tb.is_none() ? nullptr : tb.release().ptr(); + thread_state->exc_info->exc_traceback = new_tb; + Py_XDECREF(old_exc_traceback); + }, + nb::arg("traceback").none()); +#endif // PY_VERSION_HEX < 0x30b0000 +} +} // namespace xla diff --git a/tests/ci_clangformat/traceback.h b/tests/ci_clangformat/traceback.h new file mode 100644 index 0000000..993839b --- /dev/null +++ b/tests/ci_clangformat/traceback.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_TRACEBACK_H_ +#define JAXLIB_TRACEBACK_H_ + +#include + +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "jaxlib/nb_class_ptr.h" +#include "nanobind/nanobind.h" + +namespace xla { + +// Represents a Python traceback. This object is designed to be allocated on +// the Python heap; creating or destroying a traceback requires the GIL. +class Traceback { + public: + // Requires GIL. Creates a Traceback object that requires destructor to be + // invoked with GIL held as well. + static std::optional> Get(); + + // Requires GIL. + static bool enabled() { return enabled_; } + // Requires GIL. + static void SetEnabled(bool enabled); + + // Requires GIL. Don't call this directly, you're looking for Get(). + Traceback(); + // Requires GIL. + ~Traceback(); + + Traceback(const Traceback &) = delete; + Traceback(Traceback &&other) noexcept; + Traceback &operator=(const Traceback &) = delete; + Traceback &operator=(Traceback &&) = delete; + + // Requires the GIL be held. 
+ std::string ToString() const; + + struct Frame { + nanobind::str file_name; + nanobind::str function_name; + int function_start_line; + int line_num; + + std::string ToString() const; + }; + std::vector Frames() const; + + const absl::InlinedVector, 32> &raw_frames() + const { + return frames_; + } + + // Returns the traceback as a fake Python Traceback object, suitable for + // using as an exception traceback. + nanobind::object AsPythonTraceback() const; + + bool operator==(const Traceback &other) const { + return frames_ == other.frames_; + } + bool operator!=(const Traceback &other) const { + return frames_ != other.frames_; + } + + private: + // Each frame is a pair of a code object and a "lasti" instruction location + // in bytes. The size of _Py_CODEUNIT has changed across different Python + // versions; the lasti value here has already been multiplied by + // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions + // like PyCode_Addr2Line(). + absl::InlinedVector, 32> frames_; + + // Protected by GIL. + static bool enabled_; +}; + +using nb_traceback = nb_class_ptr; + +template +H AbslHashValue(H h, const Traceback &traceback) { + h = H::combine(std::move(h), traceback.raw_frames()); + return h; +} + +void BuildTracebackSubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_TRACEBACK_H_ diff --git a/tests/ci_clangformat/util.cc b/tests/ci_clangformat/util.cc new file mode 100644 index 0000000..dc6ee2f --- /dev/null +++ b/tests/ci_clangformat/util.cc @@ -0,0 +1,85 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/util.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +void BlockUntilReadyWithCancel(xla::PjRtFuture<> &future) { +#if JAX_IFRT_VERSION_NUMBER >= 5 + future.BlockUntilReady([](tsl::AsyncValue *value) { + auto state = std::make_shared(); + value->AndThen([state]() { state->Notify(); }); + while (true) { + if (state->WaitForNotificationWithTimeout(absl::Milliseconds(200))) { + break; + } + nanobind::gil_scoped_acquire gil_acquire; + if (PyErr_CheckSignals() != 0) { + throw nanobind::python_error(); + } + } + }); +#endif +} + +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { + if (ifrt_arrays.empty()) { + return absl::OkStatus(); + } + + ifrt::Future<> future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector values; + values.reserve(ifrt_arrays.size()); + for (ifrt::Array *const ifrt_array : ifrt_arrays) { + values.push_back(tsl::FormRef(ifrt_array)); + } + ifrt::Client *const client = ifrt_arrays.front()->client(); + future = client->GetReadyFuture(values); + } + BlockUntilReadyWithCancel(future); + absl::Status s = future.Await(); + if (!s.ok()) { + // Fix up error string because some clients rely on it. + if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { + s = InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + } + return s; +} + +} // namespace xla diff --git a/tests/ci_clangformat/util.h b/tests/ci_clangformat/util.h new file mode 100644 index 0000000..6b9726d --- /dev/null +++ b/tests/ci_clangformat/util.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_UTIL_H_ +#define JAXLIB_UTIL_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" + +namespace xla { + +// Waits until future is ready but will cancel if ctrl-c is pressed. +void BlockUntilReadyWithCancel(xla::PjRtFuture<> &future); + +// Requests if given buffers are ready, awaits for results and returns OK if +// all of the buffers are ready or the last non-ok status. 
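+//
+// Hypothetical call site (editorial sketch only; the element type of the
+// span is assumed from the declaration below):
+//   std::vector<xla::ifrt::Array*> arrays = ...;
+//   absl::Status s = xla::AwaitBuffersReady(absl::MakeSpan(arrays));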
+absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); + +} // namespace xla + +#endif // JAXLIB_UTIL_H_ diff --git a/tests/ci_clangformat/utils.cc b/tests/ci_clangformat/utils.cc new file mode 100644 index 0000000..b18cc47 --- /dev/null +++ b/tests/ci_clangformat/utils.cc @@ -0,0 +1,300 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" + +namespace nb = nanobind; + +namespace { + +// A variant of map(...) that: +// a) returns a list instead of an iterator, and +// b) checks that the input iterables are of equal length. +PyObject *SafeMap(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + if (nargs < 2) { + PyErr_SetString(PyExc_TypeError, "safe_map requires at least 2 arguments"); + return nullptr; + } + PyObject *fn = args[0]; + absl::InlinedVector iterators; + iterators.reserve(nargs - 1); + for (Py_ssize_t i = 1; i < nargs; ++i) { + PyObject *it = PyObject_GetIter(args[i]); + if (!it) return nullptr; + iterators.push_back(nb::steal(it)); + } + + // Try to use a length hint to estimate how large a list to allocate. + Py_ssize_t length_hint = PyObject_LengthHint(args[1], 2); + if (PyErr_Occurred()) { + PyErr_Clear(); + } + if (length_hint < 0) { + length_hint = 2; + } + + nb::list list = nb::steal(PyList_New(length_hint)); + int n = 0; // Current true size of the list + + // The arguments we will pass to fn. We allocate space for one more argument + // than we need at the start of the argument list so we can use + // PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee. + absl::InlinedVector values(nargs, nullptr); + while (true) { + absl::Cleanup values_cleanup = [&values]() { + for (PyObject *v : values) { + Py_XDECREF(v); + v = nullptr; + } + }; + values[1] = PyIter_Next(iterators[0].ptr()); + if (PyErr_Occurred()) return nullptr; + + if (values[1]) { + for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (!values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "safe_map() argument %u is shorter than argument 1", + i + 1); + return nullptr; + } + } + } else { + // No more elements should be left. Checks the other iterators are + // exhausted. + for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "safe_map() argument %u is longer than argument 1", + i + 1); + return nullptr; + } + } + + // If the length hint was too large, truncate the list to the true size. 
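+      // Passing nullptr as the replacement sequence makes PyList_SetSlice
+      // delete the unused trailing slots in [n, length_hint).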
+ if (n < length_hint) { + if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) { + return nullptr; + } + } + return list.release().ptr(); + } + + nb::object out = nb::steal(PyObject_Vectorcall( + fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, + /*kwnames=*/nullptr)); + if (PyErr_Occurred()) { + return nullptr; + } + + if (n < length_hint) { + PyList_SET_ITEM(list.ptr(), n, out.release().ptr()); + } else { + if (PyList_Append(list.ptr(), out.ptr()) < 0) { + return nullptr; + } + } + ++n; + } +} + +PyMethodDef safe_map_def = { + "safe_map", + reinterpret_cast(SafeMap), + METH_FASTCALL, +}; + +// Similar to SafeMap, but ignores the return values of the function and returns +// None. +PyObject *ForEach(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + if (nargs < 2) { + PyErr_SetString(PyExc_TypeError, "foreach() requires at least 2 arguments"); + return nullptr; + } + PyObject *fn = args[0]; + absl::InlinedVector iterators; + iterators.reserve(nargs - 1); + for (Py_ssize_t i = 1; i < nargs; ++i) { + PyObject *it = PyObject_GetIter(args[i]); + if (!it) return nullptr; + iterators.push_back(nb::steal(it)); + } + + // The arguments we will pass to fn. We allocate space for one more argument + // than we need at the start of the argument list so we can use + // PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee. + absl::InlinedVector values(nargs, nullptr); + while (true) { + absl::Cleanup values_cleanup = [&values]() { + for (PyObject *v : values) { + Py_XDECREF(v); + v = nullptr; + } + }; + values[1] = PyIter_Next(iterators[0].ptr()); + if (PyErr_Occurred()) return nullptr; + + if (values[1]) { + for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (!values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "foreach() argument %u is shorter than argument 1", + i + 1); + return nullptr; + } + } + } else { + // No more elements should be left. Checks the other iterators are + // exhausted. + for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "foreach() argument %u is longer than argument 1", + i + 1); + return nullptr; + } + } + Py_INCREF(Py_None); + return Py_None; + } + + nb::object out = nb::steal(PyObject_Vectorcall( + fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, + /*kwnames=*/nullptr)); + if (PyErr_Occurred()) { + return nullptr; + } + } +} + +PyMethodDef foreach_def = { + "foreach", reinterpret_cast(ForEach), METH_FASTCALL, + "foreach() applies a function elementwise to one or more iterables, " + "ignoring the return values and returns None. The iterables must all have " + "the same lengths."}; + +nb::list TopologicalSort(nb::str parents_attr, + nb::iterable end_nodes_iterable) { + // This is a direct conversion of the original Python implementation. + // More efficient implementations of a topological sort are possible (and + // indeed, easier to write), but changing the choice of topological order + // would break existing tests. 
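+  // For example, with parents_attr = "parents", end_nodes = [c],
+  // c.parents = [a, b], b.parents = [a], and a.parents = [], the returned
+  // order is [a, b, c]: every node appears after all of its parents.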
+ std::vector end_nodes; + absl::flat_hash_set seen; + for (nb::handle n : end_nodes_iterable) { + nb::object node = nb::borrow(n); + if (seen.insert(node.ptr()).second) { + end_nodes.push_back(node); + } + } + + nb::list sorted_nodes; + if (end_nodes.empty()) { + return sorted_nodes; + } + + std::vector stack = end_nodes; + absl::flat_hash_map child_counts; + while (!stack.empty()) { + nb::object node = std::move(stack.back()); + stack.pop_back(); + auto &count = child_counts[node.ptr()]; + if (count == 0) { + for (nb::handle parent : node.attr(parents_attr)) { + stack.push_back(nb::borrow(parent)); + } + } + ++count; + } + + for (nb::handle n : end_nodes) { + child_counts[n.ptr()] -= 1; + } + + std::vector childless_nodes; + childless_nodes.reserve(end_nodes.size()); + for (nb::handle n : end_nodes) { + if (child_counts[n.ptr()] == 0) { + childless_nodes.push_back(nb::borrow(n)); + } + } + + while (!childless_nodes.empty()) { + nb::object node = std::move(childless_nodes.back()); + childless_nodes.pop_back(); + sorted_nodes.append(node); + for (nb::handle parent : node.attr(parents_attr)) { + auto &count = child_counts[parent.ptr()]; + if (count == 1) { + childless_nodes.push_back(nb::borrow(parent)); + } else { + --count; + } + } + } + sorted_nodes.reverse(); + return sorted_nodes; +} + +} // namespace + +NB_MODULE(utils, m) { + nb::object module_name = m.attr("__name__"); + m.attr("safe_map") = nb::steal( + PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); + m.attr("foreach") = nb::steal( + PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr())); + + m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"), + nb::arg("end_nodes"), + "Computes a topological sort of a graph of objects. parents_attr is " + "the name of the attribute on each object that contains the list of " + "parent objects. end_nodes is an iterable of objects from which we " + "should start a backwards search."); + + // Python has no reader-writer lock in its standard library, so we expose + // bindings around absl::Mutex. + nb::class_(m, "Mutex") + .def(nb::init<>()) + .def("lock", &absl::Mutex::Lock, nb::call_guard()) + .def("unlock", &absl::Mutex::Unlock) + .def("assert_held", &absl::Mutex::AssertHeld) + .def("reader_lock", &absl::Mutex::ReaderLock, + nb::call_guard()) + .def("reader_unlock", &absl::Mutex::ReaderUnlock) + .def("assert_reader_held", &absl::Mutex::AssertReaderHeld) + .def("writer_lock", &absl::Mutex::WriterLock, + nb::call_guard()) + .def("writer_unlock", &absl::Mutex::WriterUnlock); +} \ No newline at end of file diff --git a/tests/ci_clangformat/weakref_lru_cache.cc b/tests/ci_clangformat/weakref_lru_cache.cc new file mode 100644 index 0000000..8a10cfa --- /dev/null +++ b/tests/ci_clangformat/weakref_lru_cache.cc @@ -0,0 +1,416 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +// Minimal wrapper to expose a nb::dict_iterator's value as something +// hashable with Abseil. +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} + + template + friend H AbslHashValue(H h, const HashablePyDictEntry &v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); + } + + std::pair entry_; +}; + +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator +// itself. Note that the iterator "is" also a Value. Does not meet the full +// standard iterator requirements, only enough to support H::combine_unordered. +class HashablePyDictIter { + public: + using iterator_category = std::input_iterator_tag; + + explicit HashablePyDictIter(nb::detail::dict_iterator &iter) : iter_(iter) {} + + // Minimal set of iterator operations. + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } + bool operator!=(const HashablePyDictIter &rhs) const { + return iter_ != rhs.iter_; + } + void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator &iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey &key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. 
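+    // Combines the context and positional args first, then the kwargs as an
+    // unordered collection of (key, value) entries together with their count.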
+ h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + +} // namespace + +class WeakrefLRUCache : public std::enable_shared_from_this { + public: + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize) + : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} + + nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs); + + std::vector GetKeys(); + + struct CacheInfo { + int64_t hits; + int64_t misses; + int64_t maxsize; + int64_t currsize; + }; + CacheInfo GetCacheInfo() const; + + void Clear(); + + static PyType_Slot slots_[]; + + private: + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} + + bool operator==(const Key &other) const { + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); + } + + template + friend H AbslHashValue(H h, const Key &key) { + return H::combine(std::move(h), key.cached_hash_); + } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + int tp_traverse(visitproc visit, void *arg) const { + Py_VISIT(context_.ptr()); + Py_VISIT(args_.ptr()); + Py_VISIT(kwargs_.ptr()); + return 0; + } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; + }; + + struct CacheEntry { + bool has_result = false; + nb::object result; + absl::Notification completed; + std::thread::id thread_id = std::this_thread::get_id(); + + int tp_traverse(visitproc visit, void *arg) const { + Py_VISIT(result.ptr()); + return 0; + } + }; + + struct WeakrefCacheKey { + nb::weakref ref; + size_t cached_hash; + }; + + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr cache; + }; + + struct WeakrefKeyHash { + size_t operator()(const WeakrefCacheKey &v) const { return v.cached_hash; } + }; + + struct WeakrefKeyEq { + bool operator()(const WeakrefCacheKey &lhs, + const WeakrefCacheKey &rhs) const { + return lhs.ref.equal(rhs.ref); + } + }; + + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue &value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); + } + return value.cache; + } + + nb::callable cache_context_fn_; + nb::callable fn_; + Cache::LRUList lru_list_; + std::unordered_map + entries_; + int64_t misses_ = 0; + int64_t total_queries_ = 0; + absl::Mutex mu_; + + static int tp_traverse(PyObject *self, visitproc visit, void *arg); + static int tp_clear(PyObject *self); +}; + +nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. 
At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); + Cache &cache = *cache_ptr; + ++total_queries_; + + bool inserted = false; + std::shared_ptr entry; + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. 
+ absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key &key) { + inserted = true; + return std::make_shared(); + }); + } + if (!entry->completed.HasBeenNotified()) { + if (inserted) { + ++misses_; + absl::Cleanup notify = [&] { entry->completed.Notify(); }; + entry->result = fn_(weakref_key, *args, **kwargs); + entry->has_result = true; + } else { + if (entry->thread_id == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + nb::gil_scoped_release release; + entry->completed.WaitForNotification(); + } + } + + if (entry->has_result) { + return entry->result; + } else { + ++misses_; + return fn_(weakref_key, *args, **kwargs); + } +} + +std::vector WeakrefLRUCache::GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto &wr_entry : entries_) { + for (const auto &rest : *wr_entry.second.cache) { + nb::tuple result = + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; +} + +WeakrefLRUCache::CacheInfo WeakrefLRUCache::GetCacheInfo() const { + CacheInfo result; + result.hits = total_queries_ - misses_; + result.misses = misses_; + result.maxsize = lru_list_.Capacity(); + result.currsize = lru_list_.Size(); + return result; +} + +void WeakrefLRUCache::Clear() { + total_queries_ = misses_ = 0; + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); + for (auto &entry : entries_) { + deferred_deletes.emplace_back(entry.first, std::move(entry.second)); + } + entries_.clear(); + deferred_deletes.clear(); +} + +/*static*/ int WeakrefLRUCache::tp_traverse(PyObject *self, visitproc visit, + void *arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + WeakrefLRUCache *cache = nb::inst_ptr(self); + Py_VISIT(cache->cache_context_fn_.ptr()); + Py_VISIT(cache->fn_.ptr()); + for (const auto &[wr_key, wr_value] : cache->entries_) { + Py_VISIT(wr_key.ref.ptr()); + for (const auto &[key, cache_value] : *wr_value.cache) { + int rval = key.tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + if (cache_value.value.has_value()) { + cache_value.value->get()->tp_traverse(visit, arg); + } + } + } + return 0; +} + +/*static*/ int WeakrefLRUCache::tp_clear(PyObject *self) { + WeakrefLRUCache *cache = nb::inst_ptr(self); + cache->Clear(); + cache->cache_context_fn_.reset(); + cache->fn_.reset(); + return 0; +} + +/* static */ PyType_Slot WeakrefLRUCache::slots_[] = { + {Py_tp_traverse, (void *)WeakrefLRUCache::tp_traverse}, + {Py_tp_clear, (void *)WeakrefLRUCache::tp_clear}, + {0, nullptr}, +}; + +NB_MODULE(weakref_lru_cache, m) { + auto weakref_lru_cache = + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable(), + nb::type_slots(WeakrefLRUCache::slots_)) + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); + nb::class_(weakref_lru_cache, + "WeakrefLRUCacheInfo") + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) 
+ .def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def("__repr__", [](WeakrefLRUCache::CacheInfo &info) { + return absl::StrCat( + "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, + ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")"); + }); + m.def( + "weakref_lru_cache", + [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) { + return std::make_shared(cache_context_fn, fn, maxsize); + }, + nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048); +} + +} // namespace jax diff --git a/tests/ci_clangformat/xla.cc b/tests/ci_clangformat/xla.cc new file mode 100644 index 0000000..b8c105d --- /dev/null +++ b/tests/ci_clangformat/xla.cc @@ -0,0 +1,984 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "jaxlib/ffi.h" +#include "jaxlib/ifrt_proxy.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_program.h" +#include "jaxlib/sdy.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/version.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT +#include "llvm/Support/Casting.h" + +#if defined(__linux__) +#include 
"gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "jaxlib/py_socket_transfer.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "jaxlib/config.h" +#include "jaxlib/custom_call_sharding.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/mlir.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pjit.h" +#include "jaxlib/pmap_lib.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_compile_only_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/xla_compiler.h" +#include "tsl/platform/platform.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace xla { +namespace { + +namespace nb = nanobind; + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(ADDRESS_SANITIZER) + return true; +#else // defined(ADDRESS_SANITIZER) + return false; +#endif +} + +bool IsMsan() { +#if defined(MEMORY_SANITIZER) + return true; +#else // defined(MEMORY_SANITIZER) + return false; +#endif +} + +bool IsTsan() { +#if defined(THREAD_SANITIZER) + return true; +#else // defined(THREAD_SANITIZER) + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(_jax, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "XlaRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. 
While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Must be before PyClient.compile. + BuildXlaCompilerSubmodule(m); + + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout &layout, + const PjRtLayout &other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout &layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) + .def("__getstate__", + [](const PjRtLayout &layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](PjRtLayout *self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtLayout((*layout)->xla_layout()); + }); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return 
std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); + ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + ifrt_client = + ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api *api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map &options, + std::shared_ptr distributed_client) + -> nb_class_ptr { + std::unique_ptr 
ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map &options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map &options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector> &py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto &py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + ifrt::DeviceListRef device_list = + client->ifrt_client()->MakeDeviceList(ifrt_devices); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_buffer_assignment_proto", + [](const CompiledMemoryStats &cms) -> nb::bytes { +#if JAX_IFRT_VERSION_NUMBER >= 7 + if (cms.buffer_assignment.has_value()) { + std::string s = + cms.buffer_assignment->SerializeAsString(); + return nb::bytes(s.data(), s.size()); + } else { + return nb::bytes(); + } +#else + xla::HloProto hlo; + if 
(!cms.serialized_hlo_proto.empty() && + hlo.ParseFromString(cms.serialized_hlo_proto)) { + std::string s = + hlo.buffer_assignment().SerializeAsString(); + return nb::bytes(s.data(), s.size()); + } + return nb::bytes(); +#endif + }) + .def("__str__", &CompiledMemoryStats::DebugString); + + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults &results) { return results.Size(); }) + .def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); + + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable &self) { + auto map = ValueOrThrow(self.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable *exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken &self) { xla::ThrowIfError(self.Await()); }); + + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken &self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule &tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); + // Legacy overload + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule &tensor, + std::optional> cpu_client, + std::optional> gpu_client) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, std::move(cpu_client), std::move(gpu_client))); + }, + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + 
m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildSdySubmodule(m); + BuildCustomCallShardingPybindAPI(m); + jax::BuildFfiSubmodule(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](tsl::PreemptionSyncManager &manager, + DistributedRuntimeClient *client) { + tsl::CoordinationServiceAgent *agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](tsl::PreemptionSyncManager &manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }) + .def("shutdown", [](tsl::PreemptionSyncManager &manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](DistributedRuntimeClient &self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](DistributedRuntimeClient &self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](DistributedRuntimeClient &client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. 
+ .def( + "blocking_key_value_get_bytes", + [](DistributedRuntimeClient &client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient &client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient &client, std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "wait_at_barrier", + [](DistributedRuntimeClient &client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](DistributedRuntimeClient &client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.GetLiveNodes(process_ids)); + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](DistributedRuntimeClient &client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient &client, absl::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, absl::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](DistributedRuntimeClient &client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](DistributedRuntimeClient &client, absl::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. 
+ std::vector> kvs; + kvs.reserve(result.size()); + for (auto &kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](DistributedRuntimeClient &client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional> + missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + std::move(*missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + 
m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return MakeCompileOnlyClient( + std::dynamic_pointer_cast(topology)) + ->Devices(); + }) + .def_prop_ro( + "platform", + [](ifrt::Topology &topology) { return topology.platform_name(); }) + .def_prop_ro( + "platform_version", + [](ifrt::Topology &topology) { return topology.platform_version(); }) + .def("serialize", + [](ifrt::Topology &topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](ifrt::Topology &topology, absl::string_view name) -> nb::object { + const auto &attrs = topology.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto &&v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_(m, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds)) + .def("get_output_shardings", &ifrt::Executable::GetOutputShardings) + .def("get_parameter_layouts", + ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts)) + .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats)) + .def("serialize", + [](const ifrt::Executable &exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const ifrt::Executable &exec) { + auto attrs = ValueOrThrow(exec.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { + return ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + m.def( + "reorder_shards", + [](PyArray x, nb::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + return ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), 
nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; + + m.def("approx_top_k_reduction_output_size", + xla::ValueOrThrowWrapper(ApproxTopKReductionOutputSize), + nb::arg("input_size"), nb::arg("rank"), nb::arg("top_k"), + nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, + nb::arg("input_size_override") = -1); + + m.def("get_internal_device_put_info", + []() { return DevicePutInfo::GetInfo(); }); + +} // NOLINT(readability/fn_size) + +} // namespace xla diff --git a/tests/ci_clangformat/xla_compiler.cc b/tests/ci_clangformat/xla_compiler.cc new file mode 100644 index 0000000..8e5389f --- /dev/null +++ b/tests/ci_clangformat/xla_compiler.cc @@ -0,0 +1,1451 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/py_client.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/array.h" +#include "xla/client/executable_build_options.h" +#include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/service/computation_placer.h" +#include "xla/service/custom_call_target_registry.h" +#include 
"xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +// Converts a computation to a serialized HloModuleProto. +absl::StatusOr GetComputationSerializedProto( + const XlaComputation &computation) { + std::string result; + if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a hlo module to a serialized HloModuleProto. +absl::StatusOr GetHloModuleSerializedProto(const HloModule &module) { + std::string result; + if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a serialized HloModuleProto into a HloModule. +absl::StatusOr> HloModuleFromSerializedProto( + const nb::bytes &bytes) { + HloModuleProto proto; + proto.ParseFromArray(bytes.c_str(), bytes.size()); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + return std::shared_ptr(std::move(module)); +} + +absl::StatusOr> GetHloModule( + const XlaComputation &computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return std::shared_ptr(std::move(module)); +} + +// Converts a computation to textual HLO form. +absl::StatusOr GetComputationHloText( + const XlaComputation &computation, bool print_large_constants = false) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(print_large_constants); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. +absl::StatusOr GetComputationHloDotGraph( + const XlaComputation &computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +// Hashes the HLO module. +absl::StatusOr HashComputation(const XlaComputation &computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return absl::HashOf(*hlo_module); +} +// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully on +// invalid input. 
+absl::StatusOr MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dims, + std::optional> minor_to_major, + std::optional> dynamic_dimensions) { + Shape shape; + if (dynamic_dimensions) { + TF_ASSIGN_OR_RETURN( + shape, ShapeUtil::MakeValidatedShape(element_type, dims, + dynamic_dimensions.value())); + } else { + TF_ASSIGN_OR_RETURN(shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + } + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } + + return shape; +} + +// Pybind function for HloSharding.iota_tile, which is a non-crashing factory +// that produces a HloSharding instance backed by tile assignment of a +// transposed and reshaped iota array of device ids. More specifically the tile +// assignment array is as if it is produced by the following numpy code: +// numpy.arange(math.prod(dims)).reshape(reshape_dims) +// .transpose(transpose_perm).reshape(math.prod(dims)) +// where: +// `dims`: is the dimensions of the tile assignment array, which corresponds to +// OpSharding.tile_assignment_dimensions. +// `reshape_dims`: is the dimensions the 1D iota array is reshaped to. +// `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. +// `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` +// dimensions in `dims`. +// +// In practice, `reshape_dims` often maps to the axises of user defined device +// mesh, and `transpose_perm` often maps to the user specification of how a +// tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh +// of size 4x2x2 as AxBxC: +// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[0,1,2] (no transpose) +// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[1,0,2] (swap A and B) +absl::StatusOr IotaTileHelper( + absl::Span dims, absl::Span reshape_dims, + absl::Span transpose_perm, + absl::Span subgroup_types) { + if (dims.empty()) { + return InvalidArgument("`dims` should not be empty."); + } + if (reshape_dims.size() != transpose_perm.size()) { + return InvalidArgument( + "`reshape_dims` and `transpose_perm` should have the same size, saw " + "[%s] v.s. [%s]", + absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ",")); + } + if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) { + return InvalidArgument( + "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ",")); + } + if (subgroup_types.size() > dims.size()) { + return InvalidArgument( + "`subgroup_types`(%lld) should not have more dimensions than " + "`dims`(%lld).", + subgroup_types.size(), dims.size()); + } + if (reshape_dims.empty()) { + return subgroup_types.empty() + ? HloSharding::IotaTile(dims) + : HloSharding::Subgroup(TileAssignment(dims), subgroup_types); + } + return subgroup_types.empty() + ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm) + : HloSharding::Subgroup( + TileAssignment(dims, reshape_dims, transpose_perm), + subgroup_types); +} + +// Registers a 'fn' as a custom call target. +// +// `fn` must be a custom call implementation function pointer (XLA_FFI_Handler* +// when implemented as FFI handler) encapsulated in a PyCapsule object or a +// a dictionary of function pointers (also encapsulated in a PyCapsule). 
+// +// See XLA_FFI_ExecutionStage documentation for more details about the +// custom execution stages. +absl::Status PyRegisterCustomCallTarget(const std::string &fn_name, + nb::object fn, + const std::string &platform, + int api_version, + XLA_FFI_Handler_Traits traits) { + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + if (traits != 0) { + return absl::InvalidArgumentError( + "Custom call target registration with traits is not supported for " + "api_version=0"); + } + + nb::capsule capsule; + if (!nb::try_cast(fn, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=0 requires a " + "PyCapsule fn object"); + } + + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + nb::capsule capsule; + if (nb::try_cast(fn, capsule)) { + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast( + static_cast(capsule.data())))); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = + [&](const char *name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + + nb::capsule capsule; + if (!nb::try_cast(bundle[name], capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=1 requires a " + "PyCapsule fn object for all dict keys"); + } + + return reinterpret_cast(capsule.data()); + }; + + XLA_FFI_Handler_Bundle bundle; + TF_ASSIGN_OR_RETURN(bundle.instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(bundle.prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(bundle.initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(bundle.execute, handler("execute")); + + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, bundle, traits)); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. 
" + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId *type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + +template +void DefRepeatedProperty(nb::class_ &cls, const char *name, + Container *(T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T &obj) { + Container *elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T &obj, std::vector new_elems) { + Container *elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type &e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_ &cls, const char *name, + Container *(T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T &obj) { + Container *elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T &obj, nb::sequence new_elems) { + Container *elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t *val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal &obj) { + const Shape &shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout &layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!LayoutUtil::IsDenseArray(shape)) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_ &m) { + // Shapes + nb::class_ layout_class(m, "Layout"); 
+ layout_class.def(nb::init>()) + .def("__init__", + [](Layout *self, nb::sequence minor_to_major, nb::sequence tiling, + int64_t element_size_in_bits) { + std::vector xla_tiles; + xla_tiles.reserve(nb::len(tiling.ptr())); + for (auto tile : tiling) { + xla_tiles.push_back(Tile( + SequenceToVector(nb::cast(tile)))); + } + std::vector xla_minor_to_major = + SequenceToVector(minor_to_major); + new (self) + Layout(xla_minor_to_major, xla_tiles, element_size_in_bits); + }) + .def("minor_to_major", + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) + .def("element_size_in_bits", &Layout::element_size_in_bits) + .def("tiling", + [](Layout layout) { + std::vector result; + result.reserve(layout.tiles().size()); + for (auto &t : layout.tiles()) { + result.push_back(SpanToNbTuple(t.dimensions())); + } + return result; + }) + .def("__eq__", [](const Layout &layout, + const Layout &other) { return layout == other; }) + .def("__ne__", [](const Layout &layout, + const Layout &other) { return layout != other; }) + .def("__str__", &Layout::ToString) + .def("__hash__", + [](const Layout &layout) { return absl::HashOf(layout); }) + .def("to_string", &Layout::ToString) + .def("__getstate__", + [](const Layout &self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout *self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(ValueOrThrow(Layout::FromProto(result))); + }); + + nb::class_ shape_class(m, "Shape"); + shape_class + .def("__init__", + [](Shape *self, const std::string &s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static("array_shape", + xla::ValueOrThrowWrapper( + [](PrimitiveType type, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + std::vector dims = + SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout( + type, dims, std::nullopt, dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), + nb::arg("dims"), nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](nb_dtype dtype, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = 
std::nullopt) + .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) + .def_static( + "scalar_shape", + [](PrimitiveType type) -> Shape { + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def_static( + "scalar_shape", + [](nb_dtype dtype) -> Shape { + PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def("dimensions", + [](const Shape &shape) -> nb::tuple { + return SpanToNbTuple(shape.dimensions()); + }) + .def("layout", + [](const Shape &shape) -> Layout { return shape.layout(); }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape &shape) { + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape &shape) { + if (shape.IsTuple()) { + return nb_dtype("O"); + } + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("is_token", &Shape::IsToken) + .def("is_static", &Shape::is_static) + .def("is_dynamic", &Shape::is_dynamic) + .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, + nb::arg("dimension")) + .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, + nb::arg("dimension"), nb::arg("is_dynamic")) + .def("rank", &Shape::dimensions_size) + .def("to_serialized_proto", + [](const Shape &shape) { + ShapeProto proto = shape.ToProto(); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); + }) + .def("tuple_shapes", + [](const Shape &shape) { + return std::vector(shape.tuple_shapes()); + }) + .def("leaf_count", + [](const Shape &shape) { return ShapeUtil::GetLeafCount(shape); }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape &shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape *subshape, const ShapeIndex &) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + .def("__eq__", [](const Shape &shape, + const Shape &other) { return shape == other; }) + .def("__ne__", [](const Shape &shape, + const Shape &other) { return shape != other; }) + .def("__hash__", [](const Shape &shape) { return absl::HashOf(shape); }) + .def("__repr__", [](const Shape &shape) { + return shape.ToString(/*print_layout=*/true); + }); + + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape *self, absl::Span params, Shape result) { + new (self) ProgramShape(); + for (const Shape ¶m : params) { + self->AddParameter(param, ""); + } + *self->mutable_result() = result; + }) + .def("parameter_shapes", + static_cast &(ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + // Literals + nb::class_(m, "Literal") + .def(nb::init()) + .def("__repr__", &Literal::ToString) + .def( + "__array__", + [](std::shared_ptr obj, std::optional dtype, + std::optional copy) { + // Provides the interface required by numpy to create a np.ndarray. + // Currently don't support the __dl_pack__ interface but can be + // added with very little effort it if needed. 
+ + nb::ndarray np_array(LiteralToNdarray(*obj)); + + if (dtype.has_value()) { + throw XlaRuntimeError( + "Passing of dtype to __array__ not currently supported."); + } + + if (copy.has_value() && *copy) { + // when a copy is requested we _must_ return a copy: + // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html + return np_array.cast(nb::rv_policy::copy); + } + + return np_array.cast(nb::rv_policy::reference_internal, + nb::cast(obj)); + }, + nb::arg("dtype").none() = nb::none(), + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); + + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation *self, + const nb::bytes &serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) + .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) + .def("program_shape", + xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) + .def("name", &XlaComputation::name) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetComputationSerializedProto)) + .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), + nb::arg("print_large_constants") = false) + .def("as_hlo_dot_graph", + xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) + .def("hash", xla::ValueOrThrowWrapper(HashComputation)) + .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); + + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) + .def_static("short_parsable", &HloPrintOptions::ShortParsable) + .def_static("canonical", &HloPrintOptions::Canonical) + .def_static("fingerprint", &HloPrintOptions::Fingerprint) + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + &HloPrintOptions::set_compact_operands) + .def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", 
+ &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); + + // HloModule.computations() returns raw pointers. + // pybind seems to prefer smart pointers. + // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy + // pybind and avoid double frees. + class ComputationWrapper { + public: + ComputationWrapper(const HloComputation *comp, + const std::shared_ptr module) + : comp_(comp), module_(module) {} + absl::string_view name() const { return comp_->name(); } + void render_html(const std::string &filename) { + std::string html = xla::ValueOrThrow(RenderGraph( + *comp_, /*label=*/"", comp_->parent()->config().debug_options(), + RenderedGraphFormat::kHtml, HloRenderOptions())); + xla::ThrowIfError(tsl::WriteStringToFile( + tsl::Env::Default(), absl::StrCat(filename, ".html"), html)); + } + + private: + const HloComputation *comp_; + // The module owns the computations: if its destructor is called, the + // computations are freed. To prevent that from happening in cases where the + // module Python object goes out of scope and gets garbage collected before + // the computations, we keep a shared_ptr to the module that originated the + // computation. + const std::shared_ptr module_; + }; + + nb::class_ hlo_computation_class(m, "HloComputation"); + + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) + .def("render_html", &ComputationWrapper::render_html); + + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) + .def("to_string", + static_cast(&HloModule::ToString), + nb::arg("options") = HloPrintOptions()) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) + .def("from_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(HloModuleFromSerializedProto)) + .def("computations", + [](const std::shared_ptr m) + -> std::vector> { + std::vector> computations; + for (HloComputation *comp : m->computations()) + computations.push_back( + std::make_shared(comp, m)); + return computations; + }) + .def_prop_ro("spmd_output_sharding", + [](const HloModule &m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule &m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto ¶meter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); + + m.def("hlo_module_to_dot_graph", + [](const HloModule &hlo_module) -> std::string { + return xla::ValueOrThrow(RenderGraph( + *hlo_module.entry_computation(), /*label=*/"", + hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); + }); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](PyClient *client, const HloModule &module) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + + // Convert from HloCostAnalysis::Properties to a standard map. 
+ nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string &hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment &da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions *self) { + new (self) CompileOptions(); + DebugOptions *debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions &self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions *self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions &self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. 
+ throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( + "num_replicas", + [](const CompileOptions &options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions &options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions &options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions &options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions &options) { return options.profile_version; }, + [](CompileOptions &options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions &options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions &options, + const DeviceAssignment &device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + // Custom-call targets. 
+ m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string &platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string &platform) -> nb::dict { + nb::dict targets; + for (const auto &[name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto &[name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler *h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + 
[](DebugOptions *self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions *self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions *self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions *self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions *self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto &passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions *self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions *self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto &passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions *self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions *self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions *self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + 
&DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions *self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions *self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions *self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions *self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions &options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions &options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions &options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions &options) + -> std::optional { + return options.has_device_assignment() + ? 
std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions &options, + const nb::bytes &serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions &options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions &options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions &options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions &options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions &options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions &options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + 
op_sharding + .def_prop_ro_static( + "Type", + [op_sharding_type](const nb::object &) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object &) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding &self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding *self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding &self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding &sharding, const nb::bytes &s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding &sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding &sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def("__eq__", [](const xla::HloSharding &a, + const xla::HloSharding &b) { return a == b; }) + .def("__hash__", + [](const xla::HloSharding &self) { return absl::HashOf(self); }) + 
.def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding &self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding &self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding &self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding &self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding &self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding &self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding &self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/tests/ci_clangformat/xla_compiler.h b/tests/ci_clangformat/xla_compiler.h new file mode 100644 index 0000000..22be102 --- /dev/null +++ b/tests/ci_clangformat/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_COMPILER_H_ +#define JAXLIB_XLA_COMPILER_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildXlaCompilerSubmodule(nanobind::module_ &m); + +} // namespace xla + +#endif // JAXLIB_XLA_COMPILER_H_