diff --git a/.bazelrc b/.bazelrc index 28720d1885e9..b0df07e9fde1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -347,6 +347,10 @@ common:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241 common:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE common:ci_windows_amd64 --color=yes +# Used for rules_python bootstrap 1.8.0+ +build --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0 +test --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0 + # ############################################################################# # RBE config options below. These inherit the CI configs above and set the # remote execution backend and authentication options required to run builds diff --git a/.github/workflows/bazel_rocm.yml b/.github/workflows/bazel_rocm.yml index 862d42203912..f9ae72ba97a2 100644 --- a/.github/workflows/bazel_rocm.yml +++ b/.github/workflows/bazel_rocm.yml @@ -113,11 +113,12 @@ jobs: JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }} JAXCI_BUILD_JAX: ${{ inputs.build_jax }} JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }} + JAXCI_CLONE_MAIN_XLA: ${{ inputs.clone_main_xla }} # Begin Presubmit Naming Check - name modification requires internal check to be updated name: "linux x86, jaxlib=${{ inputs.jaxlib-version }}, ROCM=${{ inputs.rocm-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" # End Presubmit Naming Check github-rocm-presubmits steps: - - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: ROCm Info diff --git a/.github/workflows/build_rocm_artifacts.yml b/.github/workflows/build_rocm_artifacts.yml new file mode 100644 index 000000000000..8f9ddf184778 --- 
/dev/null +++ b/.github/workflows/build_rocm_artifacts.yml @@ -0,0 +1,180 @@ +# CI - Build ROCm Artifacts +# This workflow builds ROCm wheels (jax-rocm-plugin, jax-rocm-pjrt) in a ROCm container +# and uploads them to an S3 bucket. It can be triggered manually via workflow_dispatch or +# called by other workflows via workflow_call. +name: CI - Build ROCm Artifacts + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + default: "linux-x86-64-1gpu-amd" + options: + - "linux-x86-64-1gpu-amd" + artifact: + description: "Which ROCm artifact to build?" + type: choice + default: "jax-rocm-plugin" + options: + - "jax-rocm-plugin" + - "jax-rocm-pjrt" + python: + description: "Which python version should the artifact be built for?" + type: choice + default: "3.12" + options: + - "3.11" + - "3.12" + - "3.13" + - "3.14" + rocm-version: + description: "Which ROCm version to build for?" + type: string + default: "7" + clone_main_xla: + description: "Should latest XLA be used?" + type: choice + default: "1" + options: + - "1" + - "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-64-1gpu-amd" + artifact: + description: "Which ROCm artifact to build?" + type: string + default: "jax-rocm-plugin" + python: + description: "Which python version should the artifact be built for?" + type: string + default: "3.12" + rocm-version: + description: "Which ROCm version to build for?" + type: string + default: "7" + clone_main_xla: + description: "Should latest XLA be used?" + type: string + default: "1" + upload_artifacts_to_s3: + description: "Should the artifacts be uploaded to S3?" 
+ default: true + type: boolean + s3_upload_uri: + description: "S3 location prefix to where the artifacts should be uploaded" + default: 's3://jax-ci-amd/jax-rocm-plugin-wheels' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + secrets: + AWS_ACCESS_KEY_ID: + required: true + AWS_SECRET_ACCESS_KEY: + required: true + S3_BUCKET_NAME: + required: true + outputs: + s3_upload_uri: + description: "S3 location prefix to where the artifacts were uploaded" + value: ${{ jobs.build-artifacts.outputs.s3_upload_uri }} + +permissions: + id-token: write + contents: read + actions: read + +jobs: + build-artifacts: + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner }} + container: + image: "ghcr.io/rocm/jax-manylinux_2_28-rocm-7.2.0:latest" + volumes: + - /data:/data + options: >- + --device=/dev/kfd + --device=/dev/dri + --security-opt seccomp=unconfined + --group-add video + --shm-size 64G + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_OUTPUT_DIR: "${{ github.workspace }}/dist" + + name: "${{ inputs.artifact }}, py ${{ inputs.python }}, ROCm ${{ inputs.rocm-version }}" + + outputs: + s3_upload_uri: ${{ steps.store-s3-upload-uri.outputs.s3_upload_uri }} + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + - name: Install Bazelisk + run: | + curl -fSsL -o /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-amd64 + chmod +x /usr/local/bin/bazel + - name: Install AWS CLI + run: | + curl -fSsL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip + unzip -q /tmp/awscliv2.zip -d /tmp + /tmp/aws/install + rm -rf /tmp/awscliv2.zip /tmp/aws + - name: ROCm Info + run: rocminfo + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + 
with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Create dist dir + run: mkdir -p $JAXCI_OUTPUT_DIR/ + - name: Build ${{ inputs.artifact }} + timeout-minutes: 120 + run: | + bazel --bazelrc=build/rocm/rocm.bazelrc run \ + --config=rocm_release_wheel \ + --config=rocm_rbe \ + --repo_env=HERMETIC_PYTHON_VERSION="${{ inputs.python }}" \ + $DEPLOY_TARGET -- $JAXCI_OUTPUT_DIR/ + env: + DEPLOY_TARGET: ${{ inputs.artifact == 'jax-rocm-plugin' && '//jaxlib/tools:deploy_rocm_plugin_wheel' || '//jaxlib/tools:deploy_rocm_pjrt_wheel' }} + - name: Configure AWS Credentials + if: ${{ inputs.upload_artifacts_to_s3 }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::661452401056:role/jax-ci-rocm-jax-s3-oidc + aws-region: us-east-1 + - name: Upload artifacts to S3 + if: ${{ inputs.upload_artifacts_to_s3 }} + run: | + echo "Uploading wheels to S3..." + ls -lh $JAXCI_OUTPUT_DIR/*.whl + aws s3 cp --only-show-errors --recursive $JAXCI_OUTPUT_DIR/ "${INPUTS_S3_UPLOAD_URI}"/ + echo "Upload complete." + env: + INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }} + - name: Store the S3 upload URI as an output + if: ${{ inputs.upload_artifacts_to_s3 }} + id: store-s3-upload-uri + run: echo "s3_upload_uri=${INPUTS_S3_UPLOAD_URI}" >> "$GITHUB_OUTPUT" + env: + INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }} diff --git a/.github/workflows/pyrefly.yml b/.github/workflows/pyrefly.yml deleted file mode 100644 index 2287751ea066..000000000000 --- a/.github/workflows/pyrefly.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Pyrefly type check (non-blocking) - -on: - push: - branches: - - main - pull_request: - branches: - - main - -permissions: {} - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - # Don't cancel in-progress jobs for main branches. 
- cancel-in-progress: ${{ github.ref != 'main' }} - -jobs: - pyrefly: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - name: Set up Python 3.12 - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: 3.12 - - run: python -m pip install pre-commit - - name: Run pyrefly check - run: pre-commit run pyrefly-check --hook-stage=manual --show-diff-on-failure --color=always --all-files - # This is expected to fail; we set continue-on-error so the workflow will be marked green. - continue-on-error: true diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index b6cf6c9d5037..22cd877d063e 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -93,7 +93,11 @@ jobs: gcs_download_uri: ${{ inputs.gcs_download_uri }} - name: Install Python dependencies run: | - $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + # xprof depends on cffi, which doesn't support Python 3.13 free-threaded. + if [[ "${JAXCI_HERMETIC_PYTHON_VERSION}" != "3.13-nogil" ]]; then + $JAXCI_PYTHON -m uv pip install -r build/collect-profile-requirements.txt + fi - name: Set up libtpu wheels run: | if [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "nightly" ]]; then @@ -124,10 +128,6 @@ jobs: - name: Run Pytest TPU tests timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }} run: | - if [[ ${INPUTS_PYTHON} == "3.13-nogil" ]]; then - echo "Uninstalling xprof as it is not compatible with python 3.13t." 
- $JAXCI_PYTHON -m uv pip uninstall xprof - fi ./ci/run_pytest_tpu.sh env: INPUTS_PYTHON: ${{ inputs.python }} diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 1065e9eb2aea..7eaf4a58b2f2 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -46,7 +46,7 @@ jobs: dist_docker \ --image-tag $TEST_IMAGE - name: Archive jax wheels - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }} path: ${{ env.WORKSPACE_DIR }}/dist/*.whl diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index e9b1f174f062..1c4610f14418 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -1,6 +1,6 @@ # CI - Wheel Tests (Continuous) # -# This workflow builds JAX artifacts and runs CPU/TPU/CUDA tests. +# This workflow builds JAX artifacts and runs CPU/TPU/CUDA/ROCm tests. # # It orchestrates the following: # 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and @@ -22,15 +22,26 @@ # that was built in the previous step and runs TPU tests. # 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which # runs Bazel TPU tests with py_import. +# 10. build-rocm-artifacts: Calls the `build_rocm_artifacts.yml` workflow to build ROCm plugin/pjrt +# wheels and uploads them to an S3 bucket. +# 11. run-pytest-rocm: Calls the `pytest_rocm.yml` workflow which downloads the jaxlib and +# ROCm artifacts and runs the ROCm tests. +# 12. run-bazel-test-rocm: Calls the `bazel_rocm.yml` workflow which runs the ROCm Bazel tests. 
name: CI - Wheel Tests (Continuous) permissions: - contents: read + id-token: write + contents: read + actions: read on: schedule: - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually + pull_request: + paths: + - '.github/workflows/wheel_tests_continuous.yml' + - '.github/workflows/build_rocm_artifacts.yml' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -285,4 +296,73 @@ jobs: gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} build_jaxlib: "wheel" build_jax: "wheel" - clone_main_xla: 1 \ No newline at end of file + clone_main_xla: 1 + + build-rocm-artifacts: + uses: ./.github/workflows/build_rocm_artifacts.yml + secrets: inherit + permissions: + id-token: write + contents: read + actions: read + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-1gpu-amd"] + artifact: ["jax-rocm-plugin", "jax-rocm-pjrt"] + python: ["3.11"] + rocm-version: ["7"] + name: "Build ${{ format('{0}', 'ROCm') }} artifacts" + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm-version }} + clone_main_xla: 1 + upload_artifacts_to_s3: true + s3_upload_uri: 's3://jax-ci-amd/rocm-wheels/wheel-tests-continuous/${{ github.run_number }}/${{ github.run_attempt }}' + + run-pytest-rocm: + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts] + uses: ./.github/workflows/pytest_rocm.yml + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11"] + rocm: [ + {version: "7.2.0", tag: "rocm720"}, + ] + name: "Pytest ROCm (JAX artifacts version = ${{ format('{0}', 'head') }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm.version }} + rocm-tag: ${{ matrix.rocm.tag }} + jaxlib-version: "head" + gcs_download_uri: 
${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-bazel-test-rocm: + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts] + uses: ./.github/workflows/bazel_rocm.yml + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11"] + rocm-version: ["7"] + enable-x64: [0] + name: "Bazel ROCm tests (JAX artifacts version = ${{ format('{0}', 'head') }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm-version }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + build_jaxlib: "false" + build_jax: "false" + jaxlib-version: "head" + run_multiaccelerator_tests: "false" + clone_main_xla: 1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78e9ce552d4d..aee5bc1c8429 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,14 +35,23 @@ repos: hooks: - id: ruff -- repo: https://github.com/pre-commit/mirrors-mypy - rev: a66e98df7b4aeeb3724184b332785976d062b92e # frozen: v1.19.1 +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 30778c6e83a71508a62b7297f8b22660ce4496fc # frozen: v0.55.0 hooks: - - id: mypy - files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead - additional_dependencies: [types-requests==2.31.0, numpy~=2.3.0, scipy-stubs] - args: [--config=pyproject.toml] + - id: pyrefly-check + name: Pyrefly (type checking) + pass_filenames: false + additional_dependencies: + - absl-py==2.4.0 + - types-requests~=2.32.0 + - numpy~=2.4.0 + - ml_dtypes~=0.5.0 + - opt-einsum~=3.4.0 + - scipy-stubs + - --pre + - --extra-index-url + - https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ + - jaxlib - repo: https://github.com/mwouts/jupytext rev: 8ed836db64ad5d304f2315e6bfd9049c9142e190 # frozen: v1.16.4 
@@ -51,17 +60,6 @@ repos: files: docs/ args: [--sync] -# This is a manual-only pre-commit hook to run pyrefly type checks. To run it: -# $ pre-commit run --hook-stage manual pyrefly-check --all-files -- repo: https://github.com/facebook/pyrefly-pre-commit - rev: 0ed71f5d10c035e02f24a220058b39070d165142 # frozen: v0.54.0 - hooks: - - id: pyrefly-check - name: Pyrefly (type checking) - pass_filenames: false - additional_dependencies: [absl-py==2.4.0, types-requests~=2.32.0, numpy~=2.4.0, ml_dtypes~=0.5.0, scipy-stubs] - stages: [manual] - - repo: local hooks: - id: check-copyright diff --git a/MODULE.bazel b/MODULE.bazel index 7714e6b019b9..becd6f872369 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -6,13 +6,13 @@ module(name = "jax") bazel_dep(name = "abseil-cpp", version = "20250814.1", repo_name = "com_google_absl") bazel_dep(name = "bazel_features", version = "1.36.0") -bazel_dep(name = "bazel_skylib", version = "1.8.1") +bazel_dep(name = "bazel_skylib", version = "1.8.2") bazel_dep(name = "flatbuffers", version = "25.2.10", repo_name = "com_github_google_flatbuffers") bazel_dep(name = "grpc", version = "1.78.0", repo_name = "com_github_grpc_grpc") bazel_dep(name = "platforms", version = "1.0.0") bazel_dep(name = "pybind11_bazel", version = "2.13.6") bazel_dep(name = "rules_cc", version = "0.2.9") -bazel_dep(name = "rules_python", version = "1.6.3") +bazel_dep(name = "rules_python", version = "1.8.4") # TODO: use a released version when available bazel_dep(name = "rules_ml_toolchain") @@ -27,9 +27,9 @@ archive_override( bazel_dep(name = "xla") archive_override( module_name = "xla", - integrity = "sha256-4C4sa4TtEsEo8EqOsc3uUV4ceXDHr0EnZ3EvhBvtHBs=", - strip_prefix = "xla-ba5bbceae1ff6c0f03f0234ba6beadbcdae74635", - urls = ["https://github.com/openxla/xla/archive/ba5bbceae1ff6c0f03f0234ba6beadbcdae74635.tar.gz"], + integrity = "sha256-Tw4m9BWT5Dgsnr6SEQHpDOfMSn95O6gBvYJswJJdGCI=", + strip_prefix = "xla-b0037055f03bf51d364cba09e94278963aec5bcf", + urls = 
["https://github.com/openxla/xla/archive/b0037055f03bf51d364cba09e94278963aec5bcf.tar.gz"], ) # TODO: upstream, otherwise we have to duplicate the patches in jax @@ -99,8 +99,9 @@ single_version_override( "//third_party/py:rules_python_pip_version.patch", "//third_party/py:rules_python_freethreaded.patch", "//third_party/py:rules_python_versions.patch", + "//third_party/py:rules_python_scope.patch", ], - version = "1.6.3", + version = "1.8.4", ) ### Toolchains @@ -129,7 +130,7 @@ pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") python_version = python_version, requirements_lock = "//build:requirements_lock_{}.txt".format(python_version.replace(".", "_")), whl_modifications = { - "@pypi_mods//:numpy.json": "numpy", # @pypi_mods is defined in XLA's MODULE.bazel + "@//build:numpy.json": "numpy", # numpy.json is defined in jax's //build package }, ) for python_version in [ "3.11", diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 6b14bddaf941..08fdbfa3575f 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -42,6 +42,14 @@ COMBOS = [ ("_ft", FREETHREADING_REQUIREMENTS), ] +exports_files(["numpy.json"]) + +filegroup( + name = "numpy_json", + srcs = ["numpy.json"], + visibility = ["//visibility:public"], +) + [ compile_pip_requirements( name = "requirements" + suffix, diff --git a/build/build.py b/build/build.py index 478ae0586c63..c9f84305a683 100755 --- a/build/build.py +++ b/build/build.py @@ -697,7 +697,7 @@ async def main(): if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel - wheel_build_command = copy.deepcopy(bazel_command_base) + wheel_build_command = copy.deepcopy(wheel_build_command_base) if "cuda" in args.wheels: wheel_build_command.append("--config=cuda_libraries_from_stubs") diff --git a/build/numpy.json b/build/numpy.json new file mode 100644 index 000000000000..47d761524f2b --- /dev/null +++ b/build/numpy.json @@ -0,0 +1,10 @@ +{ + "additive_build_content": "cc_library(\n name = 
\"numpy_headers\",\n hdrs = glob([\"numpy/_core/include/**/*.h\", \"numpy/core/include/**/*.h\", \"site-packages/numpy/_core/include/**/*.h\", \"site-packages/numpy/core/include/**/*.h\"], allow_empty = True),\n includes = [\"numpy/_core/include\", \"numpy/core/include\", \"site-packages/numpy/_core/include\", \"site-packages/numpy/core/include\"],\n visibility = [\"//visibility:public\"],\n)", + "data": [], + "srcs": [], + "common_extra_deps": [], + "copy_files": {}, + "copy_executables": {}, + "data_exclude_glob": [], + "srcs_exclude_glob": [] +} \ No newline at end of file diff --git a/build/rocm/rocm.bazelrc b/build/rocm/rocm.bazelrc index 9935cdce944c..ef27ec30b057 100644 --- a/build/rocm/rocm.bazelrc +++ b/build/rocm/rocm.bazelrc @@ -20,6 +20,9 @@ common:rocm --copt=-Qunused-arguments # Used for @xla//build_tools/rocm:parallel_gpu_execute common:rocm --legacy_external_runfiles=true +build:rocm_release_wheel --config=rocm +build:rocm_release_wheel --@local_config_rocm//rocm:rocm_path_type=link_only + test:rocm --test_timeout=920,2400,7200,9600 test:rocm --flaky_test_attempts=3 test:rocm --test_verbose_timeout_warnings @@ -57,3 +60,6 @@ test:rocm_rbe --strategy=TestRunner=remote,local test:rocm_rbe --worker_sandboxing=false test:rocm_rbe --repo_env=REMOTE_GPU_TESTING=1 +# Used for rules_python bootstrap 1.8.0+ +build --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0 +test --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0 diff --git a/ci/run_bazel_test_rocm_rbe.sh b/ci/run_bazel_test_rocm_rbe.sh index 82a5d87d9b8a..a01be683d616 100755 --- a/ci/run_bazel_test_rocm_rbe.sh +++ b/ci/run_bazel_test_rocm_rbe.sh @@ -40,10 +40,21 @@ for arg in "$@"; do fi done +# Set up the build environment which sets XLA_DIR if needed. +# Useful for matching the XLA version to the JAX version and debugging/testing. 
+source "ci/utilities/setup_build_environment.sh" + +OVERRIDE_XLA_REPO="" +if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then + OVERRIDE_XLA_REPO="--override_repository=xla=${JAXCI_XLA_GIT_DIR}" +fi + + bazel --bazelrc=build/rocm/rocm.bazelrc test \ --config=rocm_rbe \ --config=rocm \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + $OVERRIDE_XLA_REPO \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh index cb804d34f1b6..9774630e8f31 100755 --- a/ci/run_bazel_test_tpu.sh +++ b/ci/run_bazel_test_tpu.sh @@ -91,8 +91,8 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then # Run single-accelerator tests in parallel bazel test \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ $OVERRIDE_XLA_REPO \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ --config=ci_linux_x86_64 \ --config=ci_rbe_cache \ --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ @@ -123,8 +123,8 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then # Run multi-accelerator across all chips bazel test \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ $OVERRIDE_XLA_REPO \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ --config=ci_linux_x86_64 \ --config=ci_rbe_cache \ --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ @@ -152,8 +152,8 @@ else # Run single-accelerator tests in parallel bazel test \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ $OVERRIDE_XLA_REPO \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ --config=ci_linux_x86_64 \ 
--config=ci_rbe_cache \ --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 304dd1ab1792..3d41d1f0955f 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -18,7 +18,7 @@ # Get a list of all the wheels in the output directory. Only look for wheels # that need to be verified for manylinux compliance. -WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \)) +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" -o -name "*jax*rocm*pjrt*whl" -o -name "*jax*rocm*plugin*whl" \)) if [[ -z "$WHEELS" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" @@ -37,10 +37,11 @@ for wheel in $WHEELS; do wheel_name=$(basename $wheel) OUTPUT=${OUTPUT_FULL//${wheel_name}/} - # If a wheel is manylinux_2_27 or manylinux2014 compliant, `auditwheel show` - # will return platform tag as manylinux_2_27 or manylinux_2_17 respectively. - # manylinux2014 is an alias for manylinux_2_17. - if echo "$OUTPUT" | grep -q "manylinux_2_27"; then + # If a wheel is manylinux compliant, `auditwheel show` will return the + # platform tag. manylinux2014 is an alias for manylinux_2_17. 
+ if echo "$OUTPUT" | grep -q "manylinux_2_28"; then + printf "\n$wheel_name is manylinux_2_28 compliant.\n" + elif echo "$OUTPUT" | grep -q "manylinux_2_27"; then printf "\n$wheel_name is manylinux_2_27 compliant.\n" # jax_cudaX_plugin...aarch64.whl is consistent with tag: manylinux_2_26_aarch64" elif echo "$OUTPUT" | grep -q "manylinux_2_26"; then @@ -49,7 +50,7 @@ for wheel in $WHEELS; do printf "\n$wheel_name is manylinux2014 compliant.\n" else echo "$OUTPUT_FULL" - printf "\n$wheel_name is NOT manylinux_2_27 or manylinux2014 compliant.\n" + printf "\n$wheel_name is NOT manylinux compliant.\n" exit 1 fi done \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md index 40334bb9599a..4408b388432b 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -186,7 +186,7 @@ possible. The `git rebase -i` command might be useful to this end. ### Linting and type-checking -JAX uses [mypy](https://mypy.readthedocs.io/) and +JAX uses [Pyrefly](https://pyrefly.org/) and [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the easiest way to run these checks locally is via the [pre-commit](https://pre-commit.com/) framework: diff --git a/docs/developer.md b/docs/developer.md index 8244a299187a..75063ad8232c 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -678,20 +678,12 @@ JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 p ## Type checking -We use `mypy` to check the type hints. To run `mypy` with the same configuration as the +We use `pyrefly` to check the type hints. To run `pyrefly` with the same configuration as the github CI checks, you can use the [pre-commit](https://pre-commit.com/) framework: ``` pip install pre-commit -pre-commit run mypy --all-files -``` - -Because `mypy` can be somewhat slow when checking all files, it may be convenient to -only check files you have modified. To do this, first stage the changes (i.e. 
`git add` -the changed files) and then run this before committing the changes: - -``` -pre-commit run mypy +pre-commit run pyrefly-check --all-files ``` ## Linting diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md index 99db4d2b1354..ecca7369a3d2 100644 --- a/docs/pallas/gpu/reference.md +++ b/docs/pallas/gpu/reference.md @@ -284,11 +284,11 @@ wrong results. There are a few useful layouts we have defined for you so far: * `plgpu.Layout.WGMMA`, which is the layout in which the Hopper-generation TensorCore expects the MMA accumulator or 16-bit input operands to have in registers. -* `plgpu.Layout.WGMMA_ROW`, which is the layout obtained after the above after reducing - it along the rows. Re-broadcasting the rows is free and will produce a value with `WGMMA` - layout. -* `plgpu.Layout.WGMMA_COL`, which is an analogue of the one above, only reduced along - columns instead of rows. + Calling `plgpu.Layout.WGMMA.reduce(axes)` gives a layout suitable for values + reduced along the specified axes, e.g. `reduce(1)` for a row result and + `reduce(0)` for a column result. Re-broadcasting the reduced dimensions is + free and produces a value with + `WGMMA` layout. * `plgpu.Layout.WG_STRIDED`, where the value is partitioned equally among the 128 CUDA lanes making up a Pallas thread. The consecutive elements (after vectorization) are assigned to the lanes in a round-robin fashion. 
Very simple and effective when diff --git a/docs/pallas/tpu/sparsecore.ipynb b/docs/pallas/tpu/sparsecore.ipynb index e13de9d6ae24..40aaf3aaf368 100644 --- a/docs/pallas/tpu/sparsecore.ipynb +++ b/docs/pallas/tpu/sparsecore.ipynb @@ -562,7 +562,7 @@ " ),\n", " out_specs=pl.BlockSpec((gather_window_size, value_dim), lambda i: (i, 0)),\n", " compiler_params=pltpu.CompilerParams(\n", - " kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE,\n", + " kernel_type=pltpu.CoreType.SC_VECTOR_SUBCORE,\n", " dimension_semantics=(pltpu.PARALLEL,),\n", " ),\n", " )\n", diff --git a/docs/pallas/tpu/sparsecore.md b/docs/pallas/tpu/sparsecore.md index c2a9135aff60..56b92bb5cdc7 100644 --- a/docs/pallas/tpu/sparsecore.md +++ b/docs/pallas/tpu/sparsecore.md @@ -320,7 +320,7 @@ def gather_add_one(x, indices): ), out_specs=pl.BlockSpec((gather_window_size, value_dim), lambda i: (i, 0)), compiler_params=pltpu.CompilerParams( - kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + kernel_type=pltpu.CoreType.SC_VECTOR_SUBCORE, dimension_semantics=(pltpu.PARALLEL,), ), ) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 12069572e453..c49dfed4b5c0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -1037,21 +1037,23 @@ def lin(self, nzs_in, *primals): primals_out, f_lin = api.linearize(self.traced, *primals) return primals_out, primals - def linearized(self, primals, *tangents): + def linearized(self, primals, *tangents): # pyrefly: ignore[bad-param-name-override] _, f_lin = api.linearize(self.traced, *primals) return f_lin(*tangents) class CheckpointName(VJPHiPrimitive): + name: str + def __init__(self, name, aval): self.in_avals = aval, self.out_aval = aval self.params = dict(name=name) super().__init__() - def expand(self, x): + def expand(self, x): # pyrefly: ignore[bad-override] return x - def remat(self, policy, x): + def remat(self, policy, x): # pyrefly: ignore[bad-override] saveable = self.name in policy rem = partial(primal_left_tangent_right, x) 
if saveable else lambda x: x return x, rem diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 4e3dd2341818..677f8968c8b4 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -32,8 +32,9 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + from jax._src.hijax import HiType # pytype: disable=import-error ty = typeof(x) - if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr + if isinstance(ty, HiType): return ty.vspace_add(x, y) x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) @@ -52,8 +53,9 @@ def add_abstract(x, y): return x def zeros_like_aval(aval: core.AbstractValue) -> Array: - if hasattr(aval, 'vspace_zero'): # TODO(mattjj,dougalm): revise away hasattr - return aval.vspace_zero() + from jax._src.hijax import HiType # pytype: disable=import-error + if isinstance(aval, HiType): + return aval.vspace_zero() # pytype: disable=attribute-error return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} @@ -81,7 +83,7 @@ def p2tz(primal_value): return Zero(typeof(primal_value).to_tangent_aval()) def p2cz(primal_value): - return Zero(typeof(primal_value).to_cotangent_aval()) + return Zero(typeof(primal_value).to_ct_aval()) def _stop_gradient_impl(x: T) -> T: diff --git a/jax/_src/api.py b/jax/_src/api.py index 3cff0f5f0ce8..a5fbf7c9480e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -675,7 +675,7 @@ def fwd(*args, **kwargs): f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) f_partial, dyn_args = argnums_partial( f, argnums, args, require_static_args_hashable=False) - return _vjp(f_partial, *dyn_args, has_aux=has_aux) # type: ignore + return _vjp(f_partial, *dyn_args, has_aux=has_aux) def bwd(f_vjp, outgrad): g = f_vjp(outgrad) g = g[0] if isinstance(argnums, int) else g @@ -1160,7 +1160,7 @@ def vmap(fun: F, # rather than raising an error. 
https://github.com/jax-ml/jax/issues/2367 in_axes = tuple(in_axes) - from jax._src import hijax # type: ignore + from jax._src import hijax # pytype: disable=import-error if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types} or isinstance(in_axes, hijax.MappingSpec)): raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding " @@ -1845,8 +1845,8 @@ def cache_miss(*args, **kwargs): in_handler=in_handler, out_handler=out_handler, out_pytree_def=out_pytree_def, - input_devices=in_handler.local_devices, - input_indices=in_handler.input_indices, + input_devices=in_handler.local_devices, # pyrefly: ignore[bad-argument-type] + input_indices=in_handler.input_indices, # pyrefly: ignore[bad-argument-type] input_array_shardings=in_handler.in_shardings, out_avals=out_handler.out_avals, out_array_shardings=out_array_shardings, @@ -2087,6 +2087,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False (in_tree, out_tree), out_pvals), consts) if has_aux: [aux] = maybe_aux + assert aux_tree is not None return out_primal_py, lifted_jvp, tree_unflatten(aux_tree, aux) else: [] = maybe_aux @@ -2219,8 +2220,8 @@ def _vjp(fun, *primals, has_aux=False): out_known = [pval.is_known() for pval in out_pvals] id_map = {id(x): i for i, x in enumerate(primals_flat)} used, opaque_residuals = set(), [] - spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else # type: ignore - RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) for r in residuals] args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(), in_tree, primals_flat) @@ -2321,7 +2322,7 @@ def _vjp_check_ct_avals(cts, primal_avals): # TODO(mattjj): improve this error by flattening with keys in the first place for ct, aval in zip(cts, primal_avals): ct_aval = 
typeof(ct) - ct_aval_expected = aval.to_cotangent_aval() + ct_aval_expected = aval.to_ct_aval() if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( @@ -2351,7 +2352,7 @@ class VJP: out_tree: PyTreeDef args_res: list[Any] opaque_residuals: list[Any] - jaxpr = property(lambda self: self.fun.args[2]) # type: ignore + jaxpr = property(lambda self: self.fun.args[2]) # pytype: disable=attribute-error def __call__(self, out_ct, *extra_args): if extra_args: diff --git a/jax/_src/array.py b/jax/_src/array.py index 1ce52849a9de..03c374069b71 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -334,7 +334,7 @@ def __format__(self, format_spec): else: return repr(self) - def __getitem__(self, idx): + def __getitem__(self, idx): # pyrefly: ignore[bad-param-name-override] from jax._src.lax import lax # pytype: disable=import-error from jax._src.numpy import indexing # pytype: disable=import-error self._check_if_deleted() @@ -360,7 +360,7 @@ def __getitem__(self, idx): dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) # Squeeze on committed arrays to avoid data movement to shard 0. 
out = lax.squeeze(out, dimensions=dims) - + assert isinstance(out, ArrayImpl) return ArrayImpl( out.aval, sharding, [out], committed=False, _skip_checks=True) @@ -372,7 +372,7 @@ def __iter__(self): else: assert self.is_fully_replicated or self.is_fully_addressable if dispatch.is_single_device_sharding(self.sharding) or self.is_fully_replicated: - return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # pyrefly: ignore[missing-attribute] elif isinstance(self.sharding, PmapSharding): return (self[i] for i in range(self.shape[0])) else: @@ -433,10 +433,10 @@ def is_fully_addressable(self) -> bool: """ return self.sharding.is_fully_addressable - def __array__(self, dtype=None, context=None, copy=None): + def __array__(self, dtype=None, context=None, copy=None): # pyrefly: ignore[bad-override] # copy argument is supported by np.asarray starting in numpy 2.0 kwds = {} if copy is None else {'copy': copy} - return np.asarray(self._value, dtype=dtype, **kwds) + return np.asarray(self._value, dtype=dtype, **kwds) # pyrefly: ignore[no-matching-overload] def __dlpack__(self, *, stream: int | Any | None = None, max_version: tuple[int, int] | None = None, @@ -464,10 +464,10 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: from jax._src.dlpack import DLDeviceType # pytype: disable=import-error # pylint: disable=g-import-not-at-top - if self.platform() == "cpu": + if self.platform() == "cpu": # pyrefly: ignore[missing-attribute] return DLDeviceType.kDLCPU, 0 - elif self.platform() == "gpu": + elif self.platform() == "gpu": # pyrefly: ignore[missing-attribute] platform_version = _get_device(self).client.platform_version if "cuda" in platform_version: dl_device_type = DLDeviceType.kDLCUDA @@ -533,7 +533,7 @@ def device_buffers(self): def addressable_data(self, index: int) -> ArrayImpl: self._check_if_deleted() if self.is_fully_replicated: - return self._fully_replicated_shard() + 
return self._fully_replicated_shard() # pyrefly: ignore[missing-attribute] return self._arrays[index] @functools.cached_property @@ -550,7 +550,7 @@ def format(self): if self.is_deleted(): return Format(None, self.sharding) try: - return Format(Layout.from_pjrt_layout(self._pjrt_layout), + return Format(Layout.from_pjrt_layout(self._pjrt_layout), # pyrefly: ignore[missing-attribute] self.sharding) except _jax.JaxRuntimeError as e: msg, *_ = e.args @@ -586,7 +586,7 @@ def delete(self): return for buf in self._arrays: buf.delete() - self._arrays = None + self._arrays = None # pyrefly: ignore[bad-assignment] self._npy_value = None @use_cpp_method() @@ -760,11 +760,12 @@ def make_array_from_callback( raise TypeError( "`Layout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Format) else sharding - if not isinstance(sharding, Sharding): + processed_sharding = sharding.sharding if isinstance(sharding, Format) else sharding + if not isinstance(processed_sharding, Sharding): raise TypeError( - f"sharding should be an instance of `jax.sharding`. Got {sharding} of" - f" type {type(sharding)}") + f"sharding should be an instance of `jax.sharding`. Got {processed_sharding} of" + f" type {type(processed_sharding)}") + sharding = processed_sharding def get_data( index: Index | None, @@ -789,7 +790,7 @@ def get_data( return r if sharding.is_fully_replicated: - devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore + devices = list(sharding._internal_device_list.addressable_device_list) # Only compute data once. 
per_device_values = [get_data((slice(None),) * len(shape))] * len(devices) else: @@ -830,7 +831,7 @@ def get_data( ) if dll is not None: - devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # type: ignore + devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # pxla.batched_device_put doesn't support Layout... Take the slow route arrays = api.device_put(per_device_values, devices) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -1102,11 +1103,11 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) - arrays = list(arrays) if isinstance(arrays, tuple) else arrays + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # pyrefly: ignore[no-matching-overload] # pyrefly#2607 # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), - committed=True) + committed=True) except TypeError: if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " @@ -1155,7 +1156,7 @@ def as_slice_indices(arr: Any, idx: Index) -> tuple[ removed_dims: list[int] = [] tuple_idx = idx if isinstance(idx, tuple) else (idx,) - for dim, sub_idx in enumerate(tuple_idx): + for dim, sub_idx in enumerate(tuple_idx): # pyrefly: ignore[bad-argument-type] if isinstance(sub_idx, int): start_indices[dim] = sub_idx limit_indices[dim] = sub_idx + 1 @@ -1333,5 +1334,5 @@ def _token_global_result_handler(global_aval, out_sharding, committed): core.get_token_aval(), out_sharding, committed) def wrapper(array): return core.Token(array) - return array_handler.wrap(wrapper) # type: ignore + return array_handler.wrap(wrapper) pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 1d6cb167a026..3ceb7ac6b2b1 100644 --- a/jax/_src/basearray.pyi 
+++ b/jax/_src/basearray.pyi @@ -42,7 +42,8 @@ PrecisionLike = Any class Array: - aval: Any + @property + def aval(self) -> Any: ... @property def dtype(self) -> np.dtype: ... @@ -82,8 +83,8 @@ class Array: # these return bool for object, so ignore override errors. def __lt__(self, other: ArrayLike) -> Array: ... def __le__(self, other: ArrayLike) -> Array: ... - def __eq__(self, other: ArrayLike) -> Array: ... # type: ignore[override] - def __ne__(self, other: ArrayLike) -> Array: ... # type: ignore[override] + def __eq__(self, other: ArrayLike) -> Array: ... # pyrefly: ignore[bad-override] + def __ne__(self, other: ArrayLike) -> Array: ... # pyrefly: ignore[bad-override] def __gt__(self, other: ArrayLike) -> Array: ... def __ge__(self, other: ArrayLike) -> Array: ... @@ -111,15 +112,15 @@ class Array: def __xor__(self, other: ArrayLike) -> Array: ... def __or__(self, other: ArrayLike) -> Array: ... - def __radd__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] - def __rsub__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] - def __rmul__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __radd__(self, other: ArrayLike) -> Array: ... + def __rsub__(self, other: ArrayLike) -> Array: ... + def __rmul__(self, other: ArrayLike) -> Array: ... def __rmatmul__(self, other: ArrayLike) -> Array: ... - def __rtruediv__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rtruediv__(self, other: ArrayLike) -> Array: ... def __rfloordiv__(self, other: ArrayLike) -> Array: ... def __rmod__(self, other: ArrayLike) -> Array: ... def __rdivmod__(self, other: ArrayLike) -> Array: ... - def __rpow__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rpow__(self, other: ArrayLike) -> Array: ... def __rlshift__(self, other: ArrayLike) -> Array: ... def __rrshift__(self, other: ArrayLike) -> Array: ... def __rand__(self, other: ArrayLike) -> Array: ... 
diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 1980ab5066be..71c499db8e0c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -895,7 +895,7 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function ifrt_callback = _wrapped_callback ctx.module_context.add_host_callback(ifrt_callback) index = np.uint64(len(ctx.module_context.host_callbacks) - 1) - result = ffi.build_ffi_lowering_function( # type: ignore + result = ffi.build_ffi_lowering_function( call_target_name, has_side_effect=has_side_effect, )(ctx, *operands, index=np.uint64(index)) @@ -903,9 +903,9 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function if sharding is not None: mlir.set_sharding(result, sharding) - results = result.results # type: ignore + results = result.results if token: - token, *results = results # type: ignore + token, *results = results - return results, token, ifrt_callback # type: ignore + return results, token, ifrt_callback diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 92b879150922..41f6713c72f2 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -87,12 +87,12 @@ def __init__(self, traceback_info): def __init_subclass__(cls): jtu.register_pytree_node_class(cls) - def tree_flatten(self): + def tree_flatten(self, /): return ([], self.traceback_info) @classmethod - def tree_unflatten(cls, metadata, payload): - del payload + def tree_unflatten(cls, metadata, payload, /): + del payload # Unused. 
return cls(metadata) def get_effect_type(self) -> ErrorEffect: @@ -134,7 +134,8 @@ def tree_flatten(self): return ([], (self.traceback_info, self.prim)) @classmethod - def tree_unflatten(cls, metadata, _): + def tree_unflatten(cls, metadata, payload): + del payload return cls(*metadata) def get_effect_type(self): @@ -156,7 +157,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, metadata, payload): - return cls(*metadata, payload[0]) + return cls(*metadata, payload=payload[0]) def __str__(self): return (f'out-of-bounds indexing for array of ' @@ -227,11 +228,11 @@ def get(self) -> str | None: def get_exception(self) -> JaxException | None: """Returns Python exception if error happened, None if no error happened.""" - if any(map(np.shape, self._pred.values())): + if any(np.shape(v) for v in self._pred.values()): return self._get_batched_exception() else: - min_code = None - cur_effect = None + min_code: Int | None = None + cur_effect: ErrorEffect | None = None for error_effect, code in self._code.items(): if self._pred[error_effect]: if min_code is None or code < min_code: @@ -255,8 +256,8 @@ def _get_batched_exception(self) -> BatchedError | None: shape = np.shape(list(self._pred.values())[0]) error_mapping = {} for idx in np.ndindex(*shape): - min_code = None - cur_effect = None + min_code: Int | None = None + cur_effect: ErrorEffect | None = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore if min_code is None or code[idx] < min_code: # type: ignore[index] @@ -983,7 +984,7 @@ def shard_map_error_check( in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v) with (jshmap._extend_axis_env(mesh, manual_axes), - mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( 
diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 65ce1151533d..82e6a45686c1 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -154,7 +154,7 @@ def _get_num_slices() -> int: num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES') if not num_slices: return 1 - return int(num_slices) # type: ignore + return int(num_slices) @staticmethod @@ -162,7 +162,7 @@ def _get_slice_id() -> int: slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID') if not slice_id: return 0 - return int(slice_id) # type: ignore + return int(slice_id) @staticmethod def _get_process_id_in_slice() -> int: diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 4ff52bb4f974..3bd4c6f748e4 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -123,9 +123,13 @@ def __init__(self, base_cache: CacheInterface): self._verified_keys: set[str] = set() @property - def _path(self): + def _path(self): # pyrefly: ignore[bad-override] return self._base_cache._path + @_path.setter + def _path(self, value): + self._base_cache._path = value + def get(self, key: str) -> bytes | None: if key not in self._verified_keys: # Force a recompile the first time we see a key. diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 1fbdc189b27b..144970f58787 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -334,7 +334,7 @@ def backend_compile_and_load( # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. 
if isinstance(backend, _jax.CompileOnlyPyClient): if host_callbacks: - return backend.compile( + return backend.compile( # type: ignore module, executable_devices=executable_devices, # type: ignore compile_options=options, diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 74d7edb3ea1f..2ad8a9a713a1 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -140,8 +140,10 @@ def _compute_on_lowering(ctx, *args, jaxpr, compute_type, out_memory_spaces): tokens, out_nodes = split_list(out_nodes, [len(effects)]) tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) ctx.set_tokens_out(tokens_out) - return [mlir.wrap_with_memory_kind(on, core.mem_space_to_kind(oms), out_aval) - for on, out_aval, oms in zip(out_nodes, ctx.avals_out, out_memory_spaces)] + return [ + mlir.wrap_with_memory_kind(on, core.mem_space_to_kind(oms), out_aval) # pyrefly: ignore[bad-argument-type] + for on, out_aval, oms in zip(out_nodes, ctx.avals_out, out_memory_spaces) + ] mlir.register_lowering(compute_on_p, _compute_on_lowering) diff --git a/jax/_src/config.py b/jax/_src/config.py index f67e3f975b1f..27056bfcb51b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -906,7 +906,7 @@ def __bool__(self) -> NoReturn: raise TypeError( "bool() not supported for instances of type '{0}' " "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) # pyrefly: ignore[missing-attribute] # pyrefly#2444 + type(self).__name__)) def _set(self, value: _T) -> None: self.value = value diff --git a/jax/_src/core.py b/jax/_src/core.py index c96b6c862853..9c90d814cea7 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -196,7 +196,7 @@ def pretty_print(self, *, source_info=False, print_shapes=True, self, source_info=source_info, print_shapes=print_shapes, custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack, print_effects=print_effects) - return doc.format(**kwargs) + return doc.format(**kwargs) # pyrefly: 
ignore[missing-attribute] def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) @@ -653,7 +653,7 @@ def _true_bind(self, *args, **params): finally: trace_ctx.set_trace(prev_trace) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): # TODO(mattjj,dougalm): remove this block? try: in_type = map(typeof, args) except: pass # try lojax error message @@ -785,7 +785,7 @@ def __init__(self): self._weakref = weakref.ref(self) self.requires_low = True - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): raise NotImplementedError("must override") def invalidate(self): @@ -797,29 +797,29 @@ def is_valid(self): def __repr__(self): return f'{self.__class__.__name__}' - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): msg = (f"{type(self)} must override process_call to handle call-like " "primitives") raise NotImplementedError(msg) - def process_map(self, map_primitive, f, tracers, params): + def process_map(self, map_primitive, f, tracers, params, /): msg = (f"{type(self)} must override process_map to handle map-like " "primitives") raise NotImplementedError(msg) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): msg = (f"{type(self)} must override process_custom_jvp_call " "to handle custom_jvp primitives") raise NotImplementedError(msg) def process_custom_transpose(self, prim: Primitive, - call: lu.WrappedFun, tracers, **params): + call: lu.WrappedFun, tracers, /, **params): msg = (f"{type(self)} must override process_custom_transpose " "to handle custom_transpose_call primitives") raise NotImplementedError(msg) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, 
/, *, out_trees, symbolic_zeros): msg = (f"{type(self)} must override process_custom_vjp_call " "to handle custom_vjp primitives") @@ -958,12 +958,16 @@ def full_lower(self): raise NotImplementedError("must override: ", type(self)) def __iter__(self): + if not hasattr(self.aval, "_iter"): + raise TypeError(f"Value of type {type(self)} is not iterable.") return iter(self.aval._iter(self)) def __reversed__(self): return iter(self[::-1]) def __len__(self): + if not hasattr(self.aval, "_len"): + raise TypeError(f"Value of type {type(self)} has no length.") return self.aval._len(self) def to_concrete_value(self): @@ -1003,10 +1007,12 @@ def addressable_shards(self): @property def at(self): + if not hasattr(self.aval, "at"): + raise TypeError(f"Value of type {type(self)} does not support at().") return self.aval.at.fget(self) @property - def aval(self): + def aval(self) -> AbstractValue: raise NotImplementedError("must override") def get_referent(self) -> Any: @@ -1015,34 +1021,48 @@ def get_referent(self) -> Any: def __bool__(self): if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_bool_conversion(self) + if not hasattr(self.aval, "_bool"): + raise TypeError(f"Value of type {type(self)} is not convertible to boolean.") return self.aval._bool(self) def __int__(self): if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_scalar_conversion(self) + if not hasattr(self.aval, "_int"): + raise TypeError(f"Value of type {type(self)} is not convertible to integer.") return self.aval._int(self) def __float__(self): check_scalar_conversion(self) + if not hasattr(self.aval, "_float"): + raise TypeError(f"Value of type {type(self)} is not convertible to float.") return self.aval._float(self) def __complex__(self): check_scalar_conversion(self) + if not hasattr(self.aval, "_complex"): + raise TypeError(f"Value of type {type(self)} is not convertible to complex.") return 
self.aval._complex(self) def __hex__(self): if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_hex"): + raise TypeError(f"Value of type {type(self)} is not convertible to hex.") return self.aval._hex(self) def __oct__(self): if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_oct"): + raise TypeError(f"Value of type {type(self)} is not convertible to oct.") return self.aval._oct(self) def __index__(self): if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_index"): + raise TypeError(f"Value of type {type(self)} is not convertible to integer index.") return self.aval._index(self) # raises a useful error on attempts to pickle a Tracer. @@ -1052,15 +1072,28 @@ def __reduce__(self): "indicate an attempt to serialize/pickle a traced value.")) # raises the better error message from ShapedArray - def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val) + def __setitem__(self, key, value): + if not hasattr(self.aval, "_setitem"): + raise TypeError(f"Value of type {type(self)} is not indexable.") + return self.aval._setitem(self, key, value) # NumPy also only looks up special methods on classes. - def __array_module__(self, types): return self.aval._array_module(self, types) + def __array_module__(self, types): + if not hasattr(self.aval, "_array_module"): + raise TypeError(f"Value of type {type(self)} is not compatible with the Array API.") + return self.aval._array_module(self, types) def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here assert not config.enable_checks.value or name != "aval" + # These must raise AttributeError in the base class for backward compatibility. 
+ # TODO(jakevdp): can we change this and make them raise NotImplementedError instead? + if name in ["block_until_ready", "copy_to_host_async"]: + raise AttributeError( + f"The '{name}' method is not available on {self._error_repr()}." + f"{self._origin_msg()}") + if name == 'sharding': raise AttributeError( f"The 'sharding' attribute is not available on {self._error_repr()}. " @@ -1100,7 +1133,7 @@ def _pretty_print(self, verbose: bool = False) -> pp.Doc: return base def __repr__(self): - return self._pretty_print(verbose=False).format() + return self._pretty_print(verbose=False).format() # pyrefly: ignore[missing-attribute] def _contents(self): try: @@ -1117,20 +1150,6 @@ def addressable_data(self, index): f"The addressable_data() method was called on {self._error_repr()}." f"{self._origin_msg()}") - @property - def block_until_ready(self): - # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError( - f"The 'block_until_ready' method is not available on {self._error_repr()}." - f"{self._origin_msg()}") - - @property - def copy_to_host_async(self): - # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError( - f"The 'copy_to_host_async' method is not available on {self._error_repr()}." - f"{self._origin_msg()}") - def delete(self): raise ConcretizationTypeError(self, f"The delete() method was called on {self._error_repr()}." 
@@ -1196,7 +1215,7 @@ def check_eval_args(args): class EvalTrace(Trace): - def process_primitive(self, primitive, args, params): + def process_primitive(self, primitive, args, params, /): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error @@ -1207,7 +1226,7 @@ def process_primitive(self, primitive, args, params): check_eval_args(args) return primitive.impl(*args, **params) - def process_call(self, primitive, f, tracers, params): + def process_call(self, primitive, f, tracers, params, /): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error @@ -1216,15 +1235,15 @@ def process_call(self, primitive, f, tracers, params): return primitive.impl(f, *tracers, **params) process_map = process_call - def process_custom_transpose(self, primitive, call, tracers, **_): + def process_custom_transpose(self, primitive, call, tracers, /, **_): del primitive, _ return call.call_wrapped(*tracers) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_): + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, **_): del primitive, jvp, _ # Unused. return fun.call_wrapped(*tracers) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, **_): del primitive, fwd, bwd, _ # Unused. 
return fun.call_wrapped(*tracers) @@ -1680,7 +1699,7 @@ def has_qdd(self) -> bool: def to_tangent_aval(self) -> AbstractValue: raise NotImplementedError("must override") - def to_cotangent_aval(self) -> AbstractValue: + def to_ct_aval(self) -> AbstractValue: raise NotImplementedError("must override") # TODO(dougalm): deprecate this alias @@ -1733,6 +1752,10 @@ def shard(self, mesh, manual_axes, check_vma, spec): def unshard(self, mesh, check_vma, spec): return unshard_aval(mesh, check_vma, spec, self) + def vspace_add(self, x, y): + from jax._src.ad_util import add_jaxvals # type: ignore + return add_jaxvals(x, y) + InputType = tuple[AbstractValue, ...] OutputType = tuple[AbstractValue, ...] @@ -1751,7 +1774,7 @@ def valid_jaxtype(x) -> bool: return True -def mem_kind_to_space(mem_kind: str) -> MemorySpace: +def mem_kind_to_space(mem_kind: str | None) -> MemorySpace: if mem_kind == 'pinned_host': return MemorySpace.Host return MemorySpace.Device @@ -2301,7 +2324,7 @@ def to_tangent_aval(self): self.weak_type, sharding=self.sharding, vma=self.vma, memory_space=self.memory_space) - def to_cotangent_aval(self): + def to_ct_aval(self): dtype = primal_dtype_to_tangent_dtype(self.dtype) sharding = primal_sharding_to_cotangent_sharding(self.sharding) return ShapedArray( @@ -2325,6 +2348,11 @@ def update_vma(self, vma): def update_weak_type(self, weak_type): return self.update(weak_type=weak_type) + def nospec(self, mesh, check_vma, all_names) -> P: + # TODO(mattjj, yashkatariya): should use newly all_names in check_vma path? 
+ all_names = order_wrt_mesh(mesh, self.vma) if check_vma else all_names + return P(all_names) if all_names else P() + _bool = concretization_function_error(bool) _int = concretization_function_error(int, True) _float = concretization_function_error(float, True) @@ -2391,11 +2419,8 @@ def primal_dtype_to_tangent_dtype(primal_dtype): else: return primal_dtype -def primal_spec_to_cotangent_spec(spec): - return P(*spec, unreduced=spec.reduced, reduced=spec.unreduced) - def primal_sharding_to_cotangent_sharding(sharding): - return sharding.update(spec=primal_spec_to_cotangent_spec(sharding.spec)) + return sharding.update(spec=sharding.spec.to_ct_spec()) ############################## pvary ################################# @@ -2558,7 +2583,7 @@ def unsafe_buffer_pointer(self): return self._refs._buf.unsafe_buffer_pointer() def at(self): raise NotImplementedError() # TODO(mattjj) class ArrayRefImpl: - _aval: ShapedArray + _aval: AbstractValue _buf: Array # mutable field def __init__(self, aval, buf): @@ -2711,7 +2736,7 @@ def accum_grad_in_ref(x): class AbstractToken(AbstractValue): def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok' def to_tangent_aval(self): return self - def to_cotangent_aval(self): return self + def to_ct_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. 
@@ -2997,7 +3022,7 @@ class CallPrimitive(Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, fun_and_args, params): + def bind_with_trace(self, trace, fun_and_args, params, /): fun = fun_and_args[0] args = fun_and_args[1:] return trace.process_call(self, fun, args, params) @@ -3040,7 +3065,7 @@ class MapPrimitive(Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, fun_and_args, params): + def bind_with_trace(self, trace, fun_and_args, params, /): fun: lu.WrappedFun = fun_and_args[0] args = fun_and_args[1:] assert len(params['in_axes']) == len(args) @@ -4035,6 +4060,7 @@ def __eq__(self, other): def get_opaque_trace_state(convention=None): del convention + assert trace_ctx.trace is not None return OpaqueTraceState(trace_ctx.trace._weakref) def nonempty_axis_env() -> bool: diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index f4c2e648921b..cd75c07d2724 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -1817,7 +1817,7 @@ def combine_bias_and_mask(bias, mask, dtype): large_negative_number = get_large_negative_number(dtype) mask = jnp.where(mask, jnp.asarray(0, dtype), large_negative_number) # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] + mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # combine bias and mask if bias is None: @@ -1905,7 +1905,7 @@ def paged_attention( page_table_k, page_table_v, layout) has_bias = bias is not None has_dbias = has_bias and \ - should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + should_export_dbias(bias.shape, query.shape, layout) variadic_args = (has_bias, has_dbias) _not_used = jnp.zeros(0, dtype=query.dtype) @@ -2042,7 +2042,7 @@ def dot_product_attention( None, None, 
layout) has_bias = bias is not None has_dbias = has_bias and \ - should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + should_export_dbias(bias.shape, query.shape, layout) variadic_args = (has_bias, has_dbias) _not_used = jnp.zeros(0, dtype=query.dtype) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index c7dcdbd2ac6b..01e8a681e899 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -684,7 +684,7 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): "configs": [configs[2], configs[0]] } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) - grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) # pyrefly: ignore[bad-argument-type] + grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) # We apply a Straight-Through Estimator (STE) with zero-out behavior: if # inputs are clipped during quantization in fprop, their corresponding gradients diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 33b4da6dfe95..9dc5c859f6d5 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -389,11 +389,11 @@ class CustomJVPCallPrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): fun, jvp, tracers = args[0], args[1], args[2:] return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params) - def impl(self, fun, _, *args): # type: ignore[bad-override] + def impl(self, fun, _, *args): raise NotImplementedError def get_bind_params(self, params): @@ -556,7 +556,7 @@ def f_bwd(res, g): def __new__(cls, fun, nondiff_argnums=(), nondiff_argnames=()): if config.custom_vjp3.value: - from jax._src.hijax import custom_vjp3 # type: ignore + from jax._src.hijax import custom_vjp3 # pytype: 
disable=import-error return custom_vjp3(fun, nondiff_argnums, nondiff_argnames) else: return super().__new__(cls) @@ -952,7 +952,7 @@ def append(x, d): if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0: results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval): + if not core.typecompat(a.to_ct_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " @@ -964,9 +964,9 @@ def append(x, d): results.append(Zero(ct.aval)) else: if (not config.disable_bwd_checks.value and - not core.typecompat(a.to_cotangent_aval(), a_ := core.typeof(ct)) - and not _ref_typecompat(a.to_cotangent_aval(), a_) - and not _temporary_dtype_exception(a.to_cotangent_aval(), a_)): + not core.typecompat(a.to_ct_aval(), a_ := core.typeof(ct)) + and not _ref_typecompat(a.to_ct_aval(), a_) + and not _temporary_dtype_exception(a.to_ct_aval(), a_)): msg = ("Custom VJP bwd rule must produce an output with the same " "type as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " @@ -979,7 +979,7 @@ def append(x, d): def _ref_typecompat(a, a_): return (isinstance(a, AbstractRef) and - core.typecompat(a.to_cotangent_aval().inner_aval, a_)) + core.typecompat(a.to_ct_aval().inner_aval, a_)) # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: @@ -997,11 +997,11 @@ class CustomVJPCallPrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) - def 
impl(self, fun, fwd, bwd, *args): # type: ignore[bad-override] + def impl(self, fun, fwd, bwd, *args): raise NotImplementedError def get_bind_params(self, params): diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 1762522a8deb..b8ea0d19a360 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -83,6 +83,11 @@ def def_transpose(self, transpose: Callable): @traceback_util.api_boundary def __call__(self, out_types, res_arg, lin_arg): + if self.transpose is None: + raise ValueError( + "Missing a transpose function. Use @def_transpose to define one." + ) + _, res_tree = tree_flatten(res_arg) _, lin_tree = tree_flatten(lin_arg) args_flat, in_tree = tree_flatten((res_arg, lin_arg)) @@ -170,7 +175,7 @@ class CustomTransposePrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, call_args, params): + def bind_with_trace(self, trace, call_args, params, /): call, tracers = call_args[0], call_args[1:] return trace.process_custom_transpose(self, call, tracers, **params) diff --git a/jax/_src/debugger/colab_debugger.py b/jax/_src/debugger/colab_debugger.py index 57d785a87613..df576057fede 100644 --- a/jax/_src/debugger/colab_debugger.py +++ b/jax/_src/debugger/colab_debugger.py @@ -30,6 +30,8 @@ from google.colab import output try: import pygments + import pygments.lexers + import pygments.formatters IS_PYGMENTS_ENABLED = True except ImportError: IS_PYGMENTS_ENABLED = False diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 82b1ff67b57d..b842586b291a 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -343,7 +343,7 @@ def _different_device_order_reshard( new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, _logical_device_ids=(None if permute_order is None else tuple(permute_order.tolist()))) - new_x = xc.reorder_shards(x, new_s, ArrayCopySemantics.REUSE_INPUT) # type: ignore + new_x = xc.reorder_shards(x, 
new_s, ArrayCopySemantics.REUSE_INPUT) return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(new_x) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index d380d6f8db38..1f9f59ec1fae 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -24,7 +24,6 @@ from jax._src import config from jax._src import xla_bridge from jax._src.lib import _jax -from jax._src.lib import jaxlib_extension_version logger = logging.getLogger(__name__) @@ -150,32 +149,26 @@ def initialize(self, logger.info( 'Starting JAX distributed service on %s', coordinator_bind_address ) - if jaxlib_extension_version >= 403: - self.service = _jax.get_distributed_runtime_service( - coordinator_bind_address, num_processes, - heartbeat_timeout=heartbeat_timeout_seconds, - shutdown_timeout=shutdown_timeout_seconds, - recoverable=_ENABLE_RECOVERABILITY.value) # type: ignore - else: - self.service = _jax.get_distributed_runtime_service( - coordinator_bind_address, num_processes, - heartbeat_timeout=heartbeat_timeout_seconds, - shutdown_timeout=shutdown_timeout_seconds) # type: ignore + self.service = _jax.get_distributed_runtime_service( + coordinator_bind_address, + num_processes, + heartbeat_timeout=heartbeat_timeout_seconds, + shutdown_timeout=shutdown_timeout_seconds, + recoverable=_ENABLE_RECOVERABILITY.value, + ) self.num_processes = num_processes if self.client is not None: raise RuntimeError('distributed.initialize should only be called once.') - if jaxlib_extension_version >= 405: - self.client = _jax.get_distributed_runtime_client( - coordinator_address, process_id, init_timeout=initialization_timeout, - use_compression=True, heartbeat_timeout=heartbeat_timeout_seconds) # type: ignore - else: - self.client = _jax.get_distributed_runtime_client( - coordinator_address, process_id, init_timeout=initialization_timeout, - use_compression=True, heartbeat_timeout=heartbeat_timeout_seconds, - recoverable=_ENABLE_RECOVERABILITY.value) # type: 
ignore + self.client = _jax.get_distributed_runtime_client( + coordinator_address, + process_id, + init_timeout=initialization_timeout, + use_compression=True, + heartbeat_timeout=heartbeat_timeout_seconds, + ) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 1fcb9943b334..231189e1936c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -506,8 +506,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None, # unhashable (e.g. custom objects with a dtype attribute). The following check is # fast and covers the majority of calls to this function within JAX library code. return _issubdtype_cached( - a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type] - b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type] + a if isinstance(a, _types_for_issubdtype) else np.dtype(a), + b if isinstance(b, _types_for_issubdtype) else np.dtype(b), ) @@ -1084,7 +1084,7 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tupl if weak_type: dtype = default_types['f' if dtype in _custom_float_dtypes else dtype.kind]() # TODO(jakevdp): fix return type annotation and remove this ignore. - return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] + return (dtype, weak_type) if return_weak_type_flag else dtype def check_and_canonicalize_user_dtype(dtype, fun_name=None) -> DType: """Checks validity of a user-provided dtype, and returns its canonical form. diff --git a/jax/_src/earray.py b/jax/_src/earray.py index b45d371057e4..ca8b284ba5d6 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -29,14 +29,18 @@ # EArray is an Array that can contain extended dtypes. 
class EArray(basearray.Array): - __slots__ = ['aval', '_data'] + __slots__ = ['_aval', '_data'] __hash__ = None # type: ignore[assignment] __array_priority__ = 100 def __init__(self, aval, data): - self.aval = aval + self._aval = aval self._data = data + @property + def aval(self): + return self._aval + def block_until_ready(self): _ = self._data.block_until_ready() return self diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index fafc571da0a9..191d398972b7 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -813,8 +813,9 @@ def _export_lowered( cur_mesh = None if config.use_shardy_partitioner.value: - for sharding in itertools.chain.from_iterable([ - all_in_shardings, lowering.compile_args["out_shardings"]]): + for sharding in itertools.chain( + all_in_shardings, lowering.compile_args["out_shardings"] + ): if isinstance(sharding, sharding_impls.NamedSharding): cur_mesh = sharding.mesh break @@ -853,7 +854,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: apply_jit=True, flat_primal_fun=True, mesh=cur_mesh) # type: ignore[arg-type] - return export(fun_vjp_jax, # type: ignore[arg-type] + return export(fun_vjp_jax, # pytype: disable=wrong-arg-types platforms=exp_primal.platforms, disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) @@ -902,9 +903,9 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # Note that this does not verify any JAX custom calls, which are only # guaranteed 3w of forward compatibility, and only prevents use of new # StableHLO features from failing on older hardware. 
- target_version = hlo.get_version_from_compatibility_requirement( # pyrefly: ignore[missing-attribute] - hlo.StablehloCompatibilityRequirement.WEEK_4) # pyrefly: ignore[missing-attribute] - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( mlir_str, target_version, xb.get_backend().serialize_with_sdy) return module_serialized @@ -949,8 +950,8 @@ def _wrap_main_func( def is_token(typ, attrs): return (typ == mlir.token_type()) - orig_input_types = orig_main.type.inputs # type: ignore - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore + orig_input_types = orig_main.type.inputs + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # The order of args: platform_index_arg, dim args, token args, array args. nr_platform_index_args = 1 if has_platform_index_argument else 0 nr_dim_args = len(dim_vars) @@ -972,8 +973,8 @@ def is_token(typ, attrs): orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) # The order of results: tokens, array results - orig_output_types = orig_main.type.results # type: ignore - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) # type: ignore + orig_output_types = orig_main.type.results + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types, result_attrs)) if is_token(typ, attrs)] @@ -1374,7 +1375,7 @@ def flattened_primal_fun_jax(*args_flat): if apply_jit: if has_named_shardings or mesh: vjp_in_shardings = tuple( - _get_named_sharding(has_named_shardings, named_sharding, # type: ignore + _get_named_sharding(has_named_shardings, named_sharding, hlo_sharding, aval, mesh) # type: ignore[arg-type] for named_sharding, hlo_sharding, aval in zip( itertools.chain(in_named_shardings, out_named_shardings), @@ 
-1516,7 +1517,7 @@ def pp_arg_dim(dim_idx: int | None) -> str: # it would be ambiguous whether we should continue tracing with a result # of type `f32[c]` or `f32[d]`. shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) # type: ignore[arg-type] + exported_dim_values = [synthetic_eval.evaluate(solution[var]) for var in exported_dim_vars] def make_aval(out_aval_idx: int): @@ -1549,8 +1550,8 @@ def _call_exported_impl(*args, exported: Exported): def get_mesh_from_symbol(symtab: ir.SymbolTable) -> mesh_lib.AbstractMesh: if "mesh" not in symtab: return mesh_lib.empty_abstract_mesh - mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) # pyrefly: ignore[missing-attribute] - axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] # pyrefly: ignore[missing-attribute] + mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) + axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] if not axes: return mesh_lib.empty_abstract_mesh axes_sizes = tuple(a.size for a in axes) @@ -1625,7 +1626,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx, x, x_aval, _get_named_sharding(exported._has_named_shardings, named_sharding, None, x_aval, None), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, named_sharding, x_aval in zip( args, exported._in_named_shardings, exported.in_avals)) elif mesh: @@ -1634,7 +1635,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(False, None, hlo_sharding, x_aval, mesh), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, hlo_sharding, x_aval in zip( args, exported.in_shardings_hlo, exported.in_avals)) else: @@ -1737,7 +1738,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(True, x_sharding, None, x_aval, None), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, x_aval, x_sharding in \ 
zip(results, ctx.avals_out, exported._out_named_shardings)) elif mesh: @@ -1745,7 +1746,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(False, None, x_sharding, x_aval, mesh), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, x_aval, x_sharding in \ zip(results, ctx.avals_out, exported.out_shardings_hlo)) else: diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 579552924be7..33502e88f9a6 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -98,11 +98,11 @@ def _serialize_exported( # _has_named_shardings. in_shardings = _serialize_array( builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), - zip(exp._in_named_shardings, exp.in_shardings_hlo) # type: ignore + zip(exp._in_named_shardings, exp.in_shardings_hlo) ) out_shardings = _serialize_array( builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), - zip(exp._out_named_shardings, exp.out_shardings_hlo) # type: ignore + zip(exp._out_named_shardings, exp.out_shardings_hlo) ) ordered_effects = _serialize_array( builder, _serialize_effect, exp.ordered_effects @@ -218,9 +218,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: out_avals = _deserialize_tuple(exp.OutAvalsLength, exp.OutAvals, partial(_deserialize_aval, scope=scope, sharding=None)) in_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], in_shardings) - in_shardings = (None,) * len(in_shardings) # type: ignore + in_shardings = (None,) * len(in_shardings) out_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], out_shardings) - out_shardings = (None,) * len(out_shardings) # type: ignore + out_shardings = (None,) * len(out_shardings) platforms = _deserialize_tuple( exp.PlatformsLength, exp.Platforms, @@ -502,10 +502,10 @@ def _serialize_partition_spec(builder: flatbuffers.Builder, spec: 
partition_spec.PartitionSpec) -> int: partitions = _serialize_array(builder, _serialize_partition_spec_one_axis, spec._partitions) # pyrefly: ignore[bad-argument-type] - reduced = _serialize_array(builder, # type: ignore + reduced = _serialize_array(builder, lambda builder, ps: builder.CreateString(ps), spec.reduced) - unreduced = _serialize_array(builder, # type: ignore + unreduced = _serialize_array(builder, lambda builder, ps: builder.CreateString(ps), spec.unreduced) diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index e811e7d38469..c611eac99bc0 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -911,7 +911,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, DimSize]: return quotient, remainder except InconclusiveDimensionOperation: return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor, - scope=self.scope), # type: ignore + scope=self.scope), _DimExpr._from_operation(_DimFactor.MOD, self, divisor, scope=self.scope)) @@ -1402,7 +1402,8 @@ def symbolic_shape(shape_spec: str | None, scope: optionally, you can specify that the parsed symbolic expressions be created in the given scope. If this is missing, then a new `SymbolicScope` is created with the given `constraints`. - You cannot specify both a `scope` and `constraints`. + You cannot specify both a `scope` and `constraints` (cannot add new + constraints to a `scope`). See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. 
like: when `shape_spec` contains placeholders ("_", "..."), use this @@ -1632,7 +1633,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: t: Any t, tok = self.term(tok) t_sign = - t if next_t_negated else t - acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator] + acc = acc + t_sign if acc is not None else t_sign if tok.exact_type in self.FOLLOW_EXPR: return acc, tok next_t_negated = (tok.exact_type == tokenize.MINUS) @@ -1654,7 +1655,7 @@ def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: power, tok = self.integer(tok) f = f ** power - acc = acc * f if acc is not None else f # type: ignore[operator] + acc = acc * f if acc is not None else f if tok.exact_type in self.FOLLOW_TERM: return acc, tok # type: ignore[bad-return-type,unused-ignore] tok = self.consume_token(tok, tokenize.STAR) @@ -2029,7 +2030,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) # type: ignore[arg-type] + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) def _solve_dim_equations( eqns: list[_DimEquation], diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index cca0ae121477..eec29be1a4be 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -177,7 +177,7 @@ def include_dir() -> str: def _aval_shape(aval: core.AbstractValue) -> Shape: - return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error + return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error # pyrefly: ignore[missing-attribute] def _convert_layout_for_lowering( @@ -326,7 +326,7 @@ def _lowering( **lowering_args, )(ctx, *operands, **params) - return result.results # type: ignore + return result.results return _lowering @@ -676,7 +676,7 @@ def ffi_call_lowering( 
operand_output_aliases=dict(input_output_aliases), api_version=custom_call_api_version, backend_config=legacy_backend_config) - return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) + return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) # pyrefly: ignore[bad-return] def ffi_batching_rule( diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index 633ca76d6094..0a82b837d5cc 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -83,6 +83,7 @@ def jvp(self, primals, tangents, **params): def transpose(self, *args, **params): assert False, "must override" +AxisName = Any class HiType(core.AbstractValue): is_high = True @@ -105,7 +106,7 @@ def raise_val(self, *lo_vals: LoVal) -> HiVal: # autodiff interface def to_tangent_aval(self) -> HiType: assert False, "must override" - def to_cotangent_aval(self) -> HiType: + def to_ct_aval(self) -> HiType: return self.to_tangent_aval() # the next two are required if this type is itself a tangent type def vspace_zero(self) -> HiVal: @@ -124,11 +125,15 @@ def leading_axis_spec(self) -> MappingSpec: assert False, "must override" # shard_map interface - def shard(self, mesh, manual_axes: frozenset, check_vma: bool, spec: HipSpec + def shard(self, mesh, manual_axes: frozenset, check_vma: bool, spec: HiPspec ) -> HiType: assert False, "must override" - def unshard(self, mesh, check_vma: bool, spec: HipSpec) -> HiType: + def unshard(self, mesh, check_vma: bool, spec: HiPspec) -> HiType: assert False, "must override" + def nospec(self, mesh, check_vma: bool, all_names: tuple[AxisName, ...] + ) -> HiPspec: + assert False, "must override" + class MutableHiType(core.AbstractValue): is_high = True @@ -161,7 +166,7 @@ def to_tangent_aval(self) -> HiType: # Subclasses should override if the cotangent type is a function of primal # type. For example, CT unreduced = reduced and vice-versa. 
- def to_cotangent_aval(self) -> HiType: + def to_ct_aval(self) -> HiType: return self.to_tangent_aval() def register_hitype(val_cls, typeof_fn) -> None: @@ -297,18 +302,18 @@ class BoxEffect(effects.Effect): ... class NewBox(HiPrimitive): def is_high(self, *, treedef) -> bool: return True # type: ignore - def abstract_eval(self, *, treedef): # pyrefly: ignore[bad-override] + def abstract_eval(self, *, treedef): leaves, treedef = tree_flatten(None) qdd = BoxTypeState(tuple(leaves), treedef) return core.AvalQDD(BoxTy(), qdd), {box_effect} - def to_lojax(_, *, treedef): # pyrefly: ignore[bad-override] + def to_lojax(_, *, treedef): return Box._new(None) def jvp(_, primals, tangents, *, treedef): # pyrefly: ignore[bad-override] assert False # TODO - def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] + def transpose(_, *args, treedef): assert False # TODO new_box_p = NewBox('new_box') @@ -317,11 +322,11 @@ class BoxSet(HiPrimitive): def is_high(self, *leaf_avals, treedef) -> bool: return True # type: ignore - def abstract_eval(self, box_ty, *leaf_avals, treedef): # pyrefly: ignore[bad-override] + def abstract_eval(self, box_ty, *leaf_avals, treedef): box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef)) return [], {box_effect} # TODO better typechecking... 
- def to_lojax(_, box, *leaves, treedef): # pyrefly: ignore[bad-override] + def to_lojax(_, box, *leaves, treedef): box._val = tree_unflatten(treedef, leaves) return [] @@ -335,7 +340,7 @@ def jvp(_, primals, tangents, *, treedef): # pyrefly: ignore[bad-override] box_set_p.bind(box_dot, *val_dots, treedef=treedef) return [], [] - def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] + def transpose(_, *args, treedef): assert False # TODO box_set_p = BoxSet('box_set') @@ -343,10 +348,10 @@ def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] class BoxGet(HiPrimitive): multiple_results = True - def abstract_eval(self, box_ty, *, avals): # pyrefly: ignore[bad-override] + def abstract_eval(self, box_ty, *, avals): return avals, {box_effect} - def to_lojax(_, box, *, avals): # pyrefly: ignore[bad-override] + def to_lojax(_, box, *, avals): return tree_leaves(box._val) def jvp(_, primals, tangents, *, avals): # pyrefly: ignore[bad-override] @@ -356,7 +361,7 @@ def jvp(_, primals, tangents, *, avals): # pyrefly: ignore[bad-override] box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)) ) - def transpose(_, *args): # pyrefly: ignore[bad-override] + def transpose(_, *args): assert False # TODO box_get_p = BoxGet('box_get') @@ -414,6 +419,11 @@ def linearized(self, residuals, *tangents): raise NotImplementedError(f"for linearize support, subclass {type(self)} " "must implement `lin` and `linearized`") + # optional transpose rule, for primitives that are linear in some inputs + def transpose(self, out_ct, *maybe_accums): + raise NotImplementedError(f"for transpose support, subclass {type(self)} " + "must implement `transpose`") + # vmap interface def batch(self, axis_data, args, dims): out_dim = self.batch_dim_rule(axis_data, dims) @@ -510,13 +520,8 @@ def vjp_bwd_retval(self, res_, g): return tree_map(partial(unmap_zero, self.axis_data), self.in_dims, out, is_leaf=lambda x: x is None) # type: ignore def batch_dim_rule(self, 
axis_data, in_dims): - - def fix_dim(dim, prev_dim): - if dim is None: - return None - return dim if prev_dim is None else (dim - (prev_dim < dim)) - - in_dims_ = tree_map(fix_dim, in_dims, self.in_dims, is_leaf=lambda x: x is None) + fix = lambda d, d_: d if (d is None or d_ is None) else d - (d_ < d) # type: ignore + in_dims_ = tree_map(fix, in_dims, self.in_dims, is_leaf=lambda x: x is None) # type: ignore out_dim = self.prim.batch_dim_rule(axis_data, in_dims_) # type: ignore return tree_map(lambda d, d_: d + (d_ < d), out_dim, self.out_dim) # type: ignore @@ -619,6 +624,13 @@ def _call_hi_primitive_jvp(primals, tangents, *, _prim): return out_primals_flat, out_tangents_flat ad.primitive_jvps[call_hi_primitive_p] = _call_hi_primitive_jvp +def _call_hi_primitive_transpose(cts_flat, *primals_flat, _prim): + cts = tree_unflatten(_prim.out_tree, cts_flat) + primals = tree_unflatten(_prim.in_tree, primals_flat) + none = _prim.transpose(cts, *primals) + assert none is None +ad.fancy_transposes[call_hi_primitive_p] = _call_hi_primitive_transpose + def _call_hi_primitive_dce(used_outs_flat, eqn): _prim = eqn.params['_prim'] used_out = tree_unflatten(_prim.out_tree, used_outs_flat) @@ -757,7 +769,7 @@ def _vjp_bwd_aval_mismatch_err(path, primal_aval, ct): return if isinstance(primal_aval, AbstractRef): primal_aval = primal_aval.inner_aval - expected = primal_aval.to_cotangent_aval() + expected = primal_aval.to_ct_aval() ct_aval = ct.aval if isinstance(ct, ad_util.SymbolicZero) else typeof(ct) if (not core.typematch(expected, ct_aval) and not _temporary_dtype_exception(expected, ct_aval) and @@ -769,7 +781,7 @@ def _vjp_bwd_aval_mismatch_err(path, primal_aval, ct): def _replace_none(primal_in_aval, maybe_ct): if maybe_ct is None: - return ad_util.Zero(primal_in_aval.to_cotangent_aval()) + return ad_util.Zero(primal_in_aval.to_ct_aval()) else: return maybe_ct @@ -849,5 +861,7 @@ class Static: val: Any class MappingSpec: pass -class HipSpec: - def to_lo(self): assert False, 
"must override" +class HiPspec: + def to_lo(self) -> HiPspec: assert False, "must override" + def to_tangent_spec(self) -> HiPspec: assert False, "must override" + def to_ct_spec(self) -> HiPspec: assert False, "must override" diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 6e26e336b87c..5686763ad960 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -64,9 +64,6 @@ from jax._src.numpy import linalg as jnp_linalg from jax._src import random as jax_random -# mypy generates a lot of false positive due to re-assigned variables. -# mypy: disable-error-code="assignment, no-redef" - # The code in this file relies on the values of some flags that are defined by # jtu. Note that the following can not always be moved to a test file since # then the test file has to import jtu first (to define the flags) which is not @@ -2483,7 +2480,7 @@ def _make_reduce_harness(name, *, dtype=np.float32): # The dtype of first operand def reducer(*args): init_val = np.array(init_value, dtype=dtype) - init_values = [init_val] + init_values: list[np.ndarray] = [init_val] if nr_operands == 2: init_values.append(np.array(0, dtype=np.int32)) return lax.reduce(args[0:nr_operands], tuple(init_values), diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 784c8ecb3ff3..338ec9e23967 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -103,8 +103,10 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _is_vjp: bool, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] - jaxpr, consts = tangent_trace.to_jaxpr(out_tangents, debug_info.with_unknown_names(), source_info) # pyrefly: 
ignore[bad-argument-type] # pyrefly#2385 + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), # type: ignore + out_tangents) + jaxpr, consts = tangent_trace.to_jaxpr( + out_tangents, debug_info.with_unknown_names(), source_info) # type: ignore which_env = [(isinstance(c, pe.DynamicJaxprTracer) and getattr(c._trace, 'tag', None) is _tag) for c in consts] jaxpr = pe.move_envvars(jaxpr, tuple(which_env)) @@ -154,12 +156,12 @@ def linearize_jaxpr( ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]: if type(allow_fwds) is bool: allow_fwds = (allow_fwds,) * (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) - assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.jaxpr.outvars) - assert len(instantiate) == len(jaxpr.jaxpr.outvars) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 - return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate), # pyrefly: ignore[bad-argument-type] # pyrefly#2530 - tuple(allow_fwds), is_vjp) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + assert len(instantiate) == len(jaxpr.jaxpr.outvars) + return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate), + tuple(allow_fwds), is_vjp) @weakref_lru_cache @source_info_util.reset_name_stack() @@ -453,7 +455,7 @@ class ValAccum(GradAccum): def __init__(self, aval, val=None): self.aval = aval - self.val = Zero(aval.to_cotangent_aval()) if val is None else val + self.val = Zero(aval.to_ct_aval()) if val is None else val ct_check(self, self.val) def __repr__(self): @@ -471,7 +473,7 @@ def ct_check(primal, ct): if config.disable_bwd_checks.value: return ct_aval = ct.aval if type(ct) is Zero else typeof(ct) - ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore + ct_aval_expected = 
primal.aval.to_ct_aval() # type: ignore if not core.typematch(ct_aval, ct_aval_expected, no_dtype_check=True): # TODO(yashkatariya, mattjj): Add primitive name here for # better error message? @@ -556,7 +558,7 @@ def to_primal_tangent_pair(self, val): tangent_zero = p2tz(val) return (val, tangent_zero) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if (all(type(t) is Zero for t in tangents_in) and primitive is not core.ref_p and primitive is not core.empty_ref_p and @@ -579,7 +581,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(p) - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not Zero for t in tangents] @@ -615,7 +617,7 @@ def new_out_axes_thunk(): def process_map(self, map_primitive, f, tracers, params): return self.process_call(map_primitive, f, tracers, params) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, symbolic_zeros): + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): return primitive.bind_with_trace(self.parent_trace, (fun, jvp, *primals_in), @@ -631,7 +633,7 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, symbolic_zeros): tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): primals_in, 
tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -791,7 +793,7 @@ def to_primal_tangent_pair(self, val): tangent_zero = p2tz(val) return (val, tangent_zero) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) tangent_nzs = [type(t) is not Zero for t in tangents_in] if (all(type(t) is Zero for t in tangents_in) and @@ -818,7 +820,7 @@ def cur_qdd(self, x): return core.cur_qdd(p) def process_custom_jvp_call(self, primitive, fun: lu.WrappedFun, - jvp: lu.WrappedFun, tracers, *, + jvp: lu.WrappedFun, tracers, /, *, symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -845,7 +847,7 @@ def _f_jvp(primals, tangents): for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)] def process_custom_vjp_call(self, primitive, fun, fwd, - bwd: lu.WrappedFun, tracers, + bwd: lu.WrappedFun, tracers, /, *, out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) @@ -875,7 +877,7 @@ def process_custom_vjp_call(self, primitive, fun, fwd, tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) - def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): + def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params, /): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not Zero for t in tangents) @@ -1128,7 +1130,7 @@ def deflinear2(primitive, transpose_rule): def linear_transpose2(transpose_rule, cotangent, *args, **kwargs): if type(cotangent) is Zero: - 
return [Zero(x.aval.to_cotangent_aval()) if isinstance(x, UndefinedPrimal) + return [Zero(x.aval.to_ct_aval()) if isinstance(x, UndefinedPrimal) else None for x in args] else: return transpose_rule(cotangent, *args, **kwargs) @@ -1177,13 +1179,13 @@ def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs): assert is_undefined_primal(x) ^ is_undefined_primal(y) if is_undefined_primal(x): if type(cotangent) is Zero: - return Zero(x.aval.to_cotangent_aval()), None + return Zero(x.aval.to_ct_aval()), None else: out = lhs_rule(cotangent, x, y, **kwargs) return out, None else: if type(cotangent) is Zero: - return None, Zero(y.aval.to_cotangent_aval()) + return None, Zero(y.aval.to_ct_aval()) else: out = rhs_rule(cotangent, x, y, **kwargs) return None, out @@ -1331,7 +1333,7 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], ) -> tuple[core.ClosedJaxpr, list[bool]]: if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.out_avals) - return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate)) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate)) @weakref_lru_cache def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 8337678c24fc..fd5358c2d93b 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -52,7 +52,7 @@ MakeIotaHandler = Callable[[AxisSize], Array] def to_elt(trace: BatchTrace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: - from jax._src import hijax # type: ignore + from jax._src import hijax # pytype: disable=import-error handler = to_elt_handlers.get(type(x)) if handler: return handler(partial(to_elt, trace, get_idx), get_idx, x, spec) @@ -152,7 +152,7 @@ def _short_repr(self): @property def aval(self): - from jax._src import hijax # type: ignore + from jax._src import hijax # pytype: disable=import-error aval = core.get_aval(self.val) 
if self._trace.axis_data.spmd_name is not None: if config._check_vma.value: @@ -249,7 +249,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(val) - def process_primitive(self, p, tracers, params): # pyrefly: ignore[bad-param-name-override] + def process_primitive(self, p, tracers, params, /): vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) args_not_mapped = all(bdim is not_mapped for bdim in dims_in) if p in fancy_primitive_batchers: @@ -271,7 +271,7 @@ def process_primitive(self, p, tracers, params): # pyrefly: ignore[bad-param-na else: raise NotImplementedError(f"Batching rule for '{p}' not implemented") - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) vals, dims = unzip2(map(self.to_batch_info, tracers)) @@ -282,7 +282,7 @@ def process_call(self, call_primitive, f, tracers, params): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())] - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params, /): vals, dims = unzip2(map(self.to_batch_info, tracers)) # The logic for the dimension math below is as follows: # ╔═════════════╦════════════════════════════════════════╦═══════════╗ @@ -320,7 +320,7 @@ def new_out_axes_thunk(): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # pyrefly: ignore[bad-param-name-override] + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) jvp, out_dims2 = 
batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) @@ -330,8 +330,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, # pyrefly: ignore[bad-override] - symbolic_zeros): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees, + symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index eb6e3af21ed8..537b0290c707 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -66,7 +66,6 @@ import numpy as np -# mypy: ignore-errors map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -86,7 +85,7 @@ def _is_not_block_argument(x: IrValues) -> bool: return not isinstance(x, ir.BlockArgument) -def dense_int_elements(xs) -> ir.DenseIntElementsAttr: +def dense_int_elements(xs) -> ir.DenseElementsAttr: return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) dense_int_array = ir.DenseI64ArrayAttr.get @@ -94,8 +93,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr: def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i) -def shape_tensor(sizes: Sequence[int | ir.RankedTensorType] - ) -> ir.RankedTensorType: +def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]) -> IrValues: int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32)) i32_type = aval_to_ir_type(core.ShapedArray((), np.int32)) def lower_dim(d): @@ -107,7 +105,7 @@ def lower_dim(d): return hlo.reshape(int1d, d) ds = map(lower_dim, sizes) if not ds: - return type_cast(ir.RankedTensorType, 
ir_constant(np.array([], np.int32))) + return ir_constant(np.array([], np.int32)) elif len(ds) == 1: # pyrefly: ignore[bad-argument-type] # pyrefly#2385 return ds[0] # pyrefly: ignore[bad-index] # pyrefly#2385 else: @@ -192,10 +190,10 @@ def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: return ir_type_factory() def _array_ir_types(aval: core.ShapedArray) -> ir.Type: - aval = core.physical_aval(aval) # type: ignore + aval = core.physical_aval(aval) if not core.is_constant_shape(aval.shape): - return _dynamic_array_ir_types(aval) # type: ignore - return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)) # type: ignore + return _dynamic_array_ir_types(aval) + return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)) def _dynamic_array_ir_types(aval: core.ShapedArray) -> ir.Type: dyn_size = ir.ShapedType.get_dynamic_size() @@ -259,6 +257,7 @@ def ir_constant( A representation of the constant as an IR value or sequence of IR values. """ if const_lowering is not None: + # pyrefly: ignore[no-matching-overload] if np.shape(val) and (c_val := const_lowering.get((id(val), aval))) is not None: return c_val for t in type(val).__mro__: @@ -315,7 +314,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic, for ax in range(val.ndim))] out = hlo.broadcast_in_dim( ir.RankedTensorType.get( - val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore + val.shape, dtype_to_ir_type(collapsed_val.dtype)), _numpy_array_constant(collapsed_val), dense_int_array(other_axes)) # type: ignore return out @@ -330,7 +329,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic, np.float16, np.float32, np.float64, np.complex64, np.complex128, np.bool_, np.longlong, dtypes.bfloat16]: - register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore + register_constant_handler(_scalar_type, _ndarray_constant_handler) def _python_scalar_handler(val, aval: core.AbstractValue | None): assert isinstance(aval, 
core.ShapedArray), aval @@ -398,7 +397,7 @@ def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute np.float16, np.float32, np.float64, np.complex64, np.complex128, np.bool_, np.longlong, dtypes.bfloat16]: - register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) # type: ignore + register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) def _dtype_attribute_handler(dtype: np.dtype | np.generic) -> ir.Attribute: return ir.TypeAttr.get(dtype_to_ir_type(dtype)) @@ -979,7 +978,7 @@ def sharded_aval(aval: core.AbstractValue, return aval if not isinstance(aval, core.ShapedArray): raise NotImplementedError - return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore + return aval.update(sharding.shard_shape(aval.shape), sharding=None) def eval_dynamic_shape(ctx: LoweringRuleContext, @@ -1091,7 +1090,7 @@ def _to_physical_op_sharding( axis_ctx.manual_axes): sharding = add_manual_axes(axis_ctx, sharding, aval.ndim) if config.use_shardy_partitioner.value: - return sharding._to_sdy_sharding(aval.ndim) # type: ignore + return sharding._to_sdy_sharding(aval.ndim) return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore @@ -1333,7 +1332,7 @@ def lower_jaxpr_to_module( raise ValueError( "Cannot lower jaxpr with verifier errors. 
" + dump_module_message(ctx.module, "verification")) - except ir.MLIRError as e: + except ir.MLIRError as e: # pyrefly: ignore[missing-attribute] msg_lines = ["Cannot lower jaxpr with verifier errors:"] def emit_diagnostic_info(d): msg_lines.append(f"\t{d.message}") @@ -1364,7 +1363,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts, result_shardings): if input_output_aliases is None: - input_output_aliases = [None] * len(avals_in) + input_output_aliases: list[int | None] = [None] * len(avals_in) else: input_output_aliases = list(input_output_aliases) # To match-up in-avals to out-avals we only care about the number of @@ -1437,7 +1436,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, results_not_matched = collections.defaultdict(collections.deque) for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): if i not in aliased_output_ids and aval is not core.abstract_token: - results_not_matched[(aval.size, rm)].append(i) + results_not_matched[(aval.size, rm)].append(i) # pyrefly: ignore[missing-attribute] # For each donated argument that hasn't been aliased or donated to XLA, try to # find an output array with matching size ignoring shapes. If a matching @@ -1450,7 +1449,11 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, # then try to find an output array with matching size. if (out_donated_args[input_idx] and avals_in[input_idx] is not core.abstract_token): # pyrefly: ignore[bad-index] # pyrefly#2385 - key = (avals_in[input_idx].size, arg_memory_kinds[input_idx]) # pyrefly: ignore[bad-index] # pyrefly#2385 + key = ( + # pyrefly: ignore[missing-attribute] + avals_in[input_idx].size, # pyrefly: ignore[bad-index] # pyrefly#2385 + arg_memory_kinds[input_idx], + ) if results_not_matched.get(key, ()): # XLA donate the argument because there's a matching output array. 
results_not_matched[key].popleft() @@ -1597,7 +1600,7 @@ def lower_jaxpr_to_fun( token_types = [token_type() for _ in effects] token_avals = [core.abstract_token] * num_tokens # Order of arguments: dim vars, tokens, const_args, array inputs - input_avals = dim_var_avals + token_avals + list(in_avals) # type: ignore + input_avals = dim_var_avals + token_avals + list(in_avals) input_types = [*dim_var_types, *token_types, *input_types] output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals output_types = [*token_types, *output_types] @@ -2052,7 +2055,7 @@ def write(v: core.Var, node: IrValues): foreach(write, jaxpr.constvars, consts_for_constvars) foreach(write, jaxpr.invars, args) last_used = core.last_used(jaxpr) - if jaxlib_extension_version >= 409: + if jaxlib_extension_version >= 413: outer_traceback = outer_traceback or xc.Traceback() else: outer_traceback = None @@ -2065,8 +2068,8 @@ def write(v: core.Var, node: IrValues): tokens_in = tokens.subset(ordered_effects) eqn_name_stack = name_stack + eqn.source_info.name_stack - if jaxlib_extension_version >= 409: - traceback = (eqn.source_info.traceback or xc.Traceback()) + outer_traceback + if jaxlib_extension_version >= 413: + traceback = (eqn.source_info.traceback or xc.Traceback()) + outer_traceback # pyrefly: ignore[unsupported-operation] else: traceback = eqn.source_info.traceback loc = source_info_to_location(ctx, eqn.primitive, eqn_name_stack, traceback) @@ -2102,6 +2105,7 @@ def write(v: core.Var, node: IrValues): assert len(out_nodes) == len(eqn.outvars), (out_nodes, eqn) if ordered_effects: + assert tokens_out is not None tokens = tokens.update_tokens(tokens_out) foreach(write, eqn.outvars, out_nodes) @@ -2109,6 +2113,16 @@ def write(v: core.Var, node: IrValues): return tuple(read(v) for v in jaxpr.outvars), tokens +class CachedLoweringRule(Protocol): + def __call__( + self, + ctx: LoweringRuleContext, + *args: ir.Value | Sequence[ir.Value], + **kwargs: Any, + ) -> tuple[Sequence[ir.Value | 
Sequence[ir.Value]], bool]: + ... + + def _cached_lowering( ctx: ModuleContext, eqn: core.JaxprEqn, @@ -2144,9 +2158,9 @@ def _cached_lowering( avals_out = map(lambda v: v.aval, eqn.outvars) cache_entry = _emit_lowering_rule_as_fun( partial(_uncached_lowering, eqn.primitive, eqn.ctx, eqn.effects), - ctx, eqn.ctx, eqn.primitive, ordered_effects, avals_in, avals_out, + ctx, eqn.ctx, eqn.primitive, ordered_effects, avals_in, avals_out, # pyrefly: ignore[bad-argument-type] # pyrefly#2385 **params, - ) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 + ) ctx.lowering_cache[cache_key] = cache_entry tokens_in_args = tuple(tokens_in.get(eff) for eff in ordered_effects) @@ -2171,7 +2185,7 @@ def _cached_lowering( def _emit_lowering_rule_as_fun( - lowering_rule: LoweringRule, + lowering_rule: CachedLoweringRule, ctx: ModuleContext, eqn_ctx: core.JaxprEqnContext, primitive: core.Primitive, @@ -2219,15 +2233,19 @@ def _emit_lowering_rule_as_fun( traceback=None, avals_in=avals_in, avals_out=avals_out, tokens_in=TokenSet(zip(ordered_effects, token_args)), - tokens_out=None, jaxpr_eqn_ctx=eqn_ctx, dim_var_values=dim_var_values, + tokens_out=None, jaxpr_eqn_ctx=eqn_ctx, + dim_var_values=flatten_ir_values(dim_var_values), const_lowering=const_lowering) with source_info_to_location( ctx, primitive, source_info_util.new_name_stack(), None ): outs, inline = lowering_rule(sub_ctx, *unflattened_args, **params) if sub_ctx.tokens_out: - outs = [*[sub_ctx.tokens_out.get(eff) for eff in ordered_effects], *outs] - outs = flatten_ir_values(outs) + outs = [ + *(sub_ctx.tokens_out.get(eff) for eff in ordered_effects), + *outs + ] + outs = flatten_ir_values(outs) # pyrefly: ignore[bad-argument-type] func_dialect.return_(outs) return LoweringCacheValue(func_op, output_types, const_args, const_arg_avals, inline) @@ -2393,16 +2411,18 @@ def lower_per_platform(ctx: LoweringRuleContext, assert kept_rules # If there is a single rule left just apply the rule, without conditionals. 
if len(kept_rules) == 1: - output = kept_rules[0](ctx, *rule_args, **rule_kwargs) + output = type_cast( + Sequence[IrValues], kept_rules[0](ctx, *rule_args, **rule_kwargs) + ) + flat_output = flatten_ir_values(output) foreach( lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), - filter(_is_not_block_argument, flatten_ir_values(output)), + filter(_is_not_block_argument, flat_output), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), - flatten_ir_values(output), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), flat_output ) - return output + return flat_output assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable" @@ -2433,7 +2453,9 @@ def lower_per_platform(ctx: LoweringRuleContext, inner_ctx = ctx.replace(platforms=platforms_for_this_rule) branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): - output = rule(inner_ctx, *rule_args, **rule_kwargs) + output = type_cast( + Sequence[IrValues], rule(inner_ctx, *rule_args, **rule_kwargs) + ) try: out_nodes = flatten_ir_values(output) except TypeError as e: @@ -2483,7 +2505,7 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): wrapped_fun, ctx.avals_in, lower=True) if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects): - from jax._src.interpreters import pxla # type: ignore + from jax._src.interpreters import pxla # pytype: disable=import-error closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars) closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr) jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts @@ -2837,10 +2859,13 @@ def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> i return broadcast_in_dim(ctx, zero, aval, broadcast_dimensions=()) def add_jaxvals_lowering(ctx, x, y): + out_aval, = ctx.avals_out if (isinstance(a := ctx.avals_in[0], core.ShapedArray) and 
dtypes.issubdtype(a.dtype, dtypes.extended)): return lower_fun(lambda x, y: [a.dtype._rules.add(a.dtype, x, y)])(ctx, x, y) - return [hlo.add(x, y)] + out = hlo.add(x, y) + return [lower_with_sharding_in_types(ctx, out, out_aval)] + register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering) register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x]) @@ -2974,15 +2999,15 @@ def get_sharding_attr( sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: if isinstance(sharding, (SdyArray, SdyArrayList)): - return sharding.build() # type: ignore + return sharding.build() else: # If there are very large numbers of devices, use the proto representation. # The MHLO to HLO conversion supports both, and the proto representation is # more compact. - if len(sharding.tile_assignment_devices) > 100: # type: ignore - return ir.StringAttr.get(sharding.SerializeToString()) # type: ignore + if len(sharding.tile_assignment_devices) > 100: + return ir.StringAttr.get(sharding.SerializeToString()) else: - return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding))) # type: ignore[arg-type] + return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding))) def wrap_with_layout_op(ctx: LoweringRuleContext, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index cedf24b94a8b..33c278b29625 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -44,7 +44,6 @@ ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, JaxprEqn, Primitive, mapped_aval, unmapped_aval, get_referent, JaxprEqnContext, typeof) from jax._src.lib import _jax -from jax._src.lib import jaxlib_extension_version from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect from jax._src.tree_util import FlatTree, PyTreeDef, treedef_tuple @@ -67,11 +66,7 @@ def identity(x): return x PyTree = Any logger = logging.getLogger(__name__) -# TODO(phawkins): remove after jaxlib 0.9.1 
is the minimum -if jaxlib_extension_version >= 406: - TracebackScope = _jax.TracebackScope -else: - TracebackScope = contextlib.nullcontext # type: ignore +TracebackScope = _jax.TracebackScope class PartialVal(tuple): @@ -183,7 +178,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(const) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): with core.set_current_trace(self.parent_trace): if primitive in custom_partial_eval_rules: tracers = map(self.to_jaxpr_tracer, tracers) @@ -222,7 +217,7 @@ def default_process_primitive(self, primitive, tracers, params): out_tracer.recipe = eqn return out_tracer - def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly: ignore[bad-param-name-override] + def process_call(self, primitive, f: lu.WrappedFun, tracers, params, /): tracers = map(self.to_jaxpr_tracer, tracers) rule = call_partial_eval_rules.get(primitive) if rule: @@ -264,7 +259,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly num_new_args = len(res_tracers) + len(env_tracers) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 new_jaxpr = convert_constvars_jaxpr(jaxpr) if isinstance(primitive, core.ClosedCallPrimitive): - new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + new_jaxpr = close_jaxpr(new_jaxpr) staged_params = dict(params, call_jaxpr=new_jaxpr) staged_params = update_params(staged_params, map(op.not_, in_knowns), num_new_args) @@ -278,7 +273,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) - def process_map(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly: ignore[bad-param-name-override] + def process_map(self, primitive, f: lu.WrappedFun, tracers, params, /): tracers = map(self.to_jaxpr_tracer, tracers) update_params = 
call_param_updaters.get(primitive) or (lambda p, _, __: p) in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) @@ -350,7 +345,7 @@ def const_out_axes_thunk(): def _current_truncated_name_stack(self): return source_info_util.current_name_stack()[len(self.name_stack):] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): # pyrefly: ignore[bad-override] + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): tracers = map(self.to_jaxpr_tracer, tracers) if all(t.is_known() for t in tracers): with core.set_current_trace(self.parent_trace): @@ -362,7 +357,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): # p with core.set_current_trace(self): return fun.call_wrapped(*tracers) - def process_custom_transpose(self, prim, call, tracers, **params): + def process_custom_transpose(self, prim, call, tracers, /, **params): tracers = map(self.to_jaxpr_tracer, tracers) res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 assert all(t.is_known() for t in res_ts) @@ -381,7 +376,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symbolic_zeros): # pyrefly: ignore[bad-param-name-override] + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): tracers = map(self.to_jaxpr_tracer, tracers) if all(t.is_known() for t in tracers): vals = [t.pval[1] for t in tracers] @@ -794,7 +789,7 @@ def sort_key(t): env_vars, env_vals = unzip2(env.items()) invars = [*env_vars, *map(get_atom, in_tracers)] const_vars, const_vals = unzip2(consts.items()) - outvars = map(get_atom, out_tracers) # type: ignore[arg-type] + outvars = map(get_atom, out_tracers) jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) is_high |= 
any(x.aval.is_high for x in it.chain(const_vars, invars, outvars)) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type] @@ -1012,9 +1007,9 @@ def partial_eval_jaxpr_stateful( saveable = everything_saveable jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ _partial_eval_jaxpr_custom_cached( - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns), - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + tuple(ensure_out_inst), saveable) return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref @@ -1079,7 +1074,7 @@ def has_effects(effects) -> bool: foreach(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. - from jax._src.dispatch import device_put_p, ArrayCopySemantics # type: ignore + from jax._src.dispatch import device_put_p, ArrayCopySemantics resvars = [Var(v.aval.update(memory_space=core.mem_kind_to_space(policy.dst))) for v in eqn.outvars] offload_eqn = core.JaxprEqn( @@ -1410,7 +1405,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], """ if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.invars) - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) @@ -1604,7 +1599,7 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', 'val', 'mutable_qdd', 'parent', '_debug_info'] + __slots__ = ['_aval', 'val', 'mutable_qdd', 'parent', '_debug_info'] _trace: DynamicJaxprTrace # pyrefly: ignore[bad-override] @@ -1623,7 +1618,7 @@ def __init__(self, trace: DynamicJaxprTrace, self._trace = trace self._line_info = line_info self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError - self.aval = aval # type: 
ignore[misc] + self._aval = aval self.val = val self.mutable_qdd = core.MutableQuasiDynamicData(qdd) self.parent = parent @@ -1634,6 +1629,10 @@ def _short_repr(self): def cur_qdd(self): return self.mutable_qdd.cur_val + @property + def aval(self): + return self._aval + @property def aval_mutable_qdd(self): aval = self.aval @@ -1724,7 +1723,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"`JaxprInputEffect` {eff} is invalid." f"\n Equation: {eqn}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") eqn_invar = eqn.invars[eff.input_index] if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays: continue @@ -1740,7 +1739,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects @@ -2041,7 +2040,7 @@ def cur_qdd(self, x): source_info = source_info_util.current() return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): self.frame.is_high |= primitive.is_high(*map(typeof, tracers), **params) if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) @@ -2101,8 +2100,8 @@ def default_process_primitive(self, primitive, tracers, params, self.frame.add_eqn(eqn) # pyrefly: ignore[bad-argument-type] return out_tracers if primitive.multiple_results else out_tracers.pop() - def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, # pyrefly: 
ignore[bad-param-name-override] - params): + def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, + params, /): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None @@ -2118,7 +2117,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, # pyrefly: new_jaxpr = convert_constvars_jaxpr(jaxpr) if isinstance(call_primitive, core.ClosedCallPrimitive): - new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + new_jaxpr = close_jaxpr(new_jaxpr) new_params = dict(params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(call_primitive) if update_params: @@ -2129,7 +2128,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, # pyrefly: [*const_tracers, *in_tracers], out_avals, call_primitive, new_params, new_params['call_jaxpr'].effects, source_info=source_info) - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params, /): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) tracers = map(to_jaxpr_tracer, tracers) @@ -2163,8 +2162,8 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): [*const_tracers, *tracers], out_avals, map_primitive, new_params, effs, source_info=source_info) return out_tracers - def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, # pyrefly: ignore[bad-override] - jvp: lu.WrappedFun, tracers, + def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, + jvp: lu.WrappedFun, tracers, /, *, symbolic_zeros: bool): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return prim.bind_with_trace(core.eval_trace, (fun, jvp, *tracers), @@ -2198,9 +2197,9 @@ def jvp_jaxpr_thunk(*in_zeros): fun_jaxpr.effects, source_info=source_info) - def 
process_custom_vjp_call(self, prim: core.Primitive, # pyrefly: ignore[bad-param-name-override] + def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, + fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, /, *, out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): @@ -2238,7 +2237,7 @@ def out_trees_(): fun_jaxpr.effects, source_info=source_info) - def process_custom_transpose(self, prim: core.Primitive, # type: ignore[override] + def process_custom_transpose(self, prim: core.Primitive, # pyrefly: ignore[bad-override] call: lu.WrappedFun, tracers, *, transpose: lu.WrappedFun, out_types, @@ -2456,13 +2455,9 @@ def trace_to_jaxpr( # TODO(dougalm): remove in favor of `trace_to_jaxpr` @profiler.annotate_function def trace_to_jaxpr_dynamic( - fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue | core.AvalQDD], - *, - keep_inputs: list[bool] | None = None, - lower: bool = False, - auto_dce: bool = False, -) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: + fun: lu.WrappedFun, in_avals: Sequence[AbstractValue | core.AvalQDD], + *, keep_inputs: list[bool] | None = None, lower: bool = False, + auto_dce: bool = False) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: config.enable_checks.value and fun.debug_info.assert_arg_names(len(in_avals)) keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs parent_trace = core.trace_ctx.trace @@ -2472,15 +2467,11 @@ def trace_to_jaxpr_dynamic( # equations should be rooted at the enclosing jaxpr and not contain any # context from the callsite. Otherwise metadata from one caller would bleed # into metadata from a different caller if we, e.g., inline. 
- with ( - core.ensure_no_leaks(trace), - source_info_util.reset_name_stack(), - TracebackScope(), - ): + with (core.ensure_no_leaks(trace), source_info_util.reset_name_stack(), + TracebackScope()): source_info = source_info_util.current() in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) _check_returned_jaxtypes(fun.debug_info, ans) @@ -2489,7 +2480,6 @@ def trace_to_jaxpr_dynamic( jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, fun.debug_info, # pyrefly: ignore[bad-argument-type] # pyrefly#2385 source_info) del trace, fun, in_tracers, out_tracers, ans - config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 8d86327e2a05..f54ed143ffc0 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -476,7 +476,7 @@ def to_map_tracer(self, val): else: return MapTracer(self, val, {}) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): from jax._src.lax import parallel # pytype: disable=import-error if primitive is parallel.axis_index_p: return self.process_axis_index(**params) # pytype: disable=missing-parameter @@ -500,10 +500,10 @@ def process_primitive(self, primitive, tracers, params): return [MapTracer(self, val, out_shard_axes) for val in outvals] return MapTracer(self, outvals, out_shard_axes) - def process_call(self, call_primitive, fun, tracers, params): + def process_call(self, call_primitive, fun, tracers, params, /): raise NotImplementedError - def process_map(self, map_primitive, fun, tracers, params): + def process_map(self, map_primitive, fun, tracers, params, /): if params['devices'] is not None: raise ValueError("Nested pmap with explicit devices argument.") 
if not config.disable_jit.value: @@ -528,7 +528,7 @@ def process_map(self, map_primitive, fun, tracers, params): for v, s, dst in zip(out, outaxes, out_axes_thunk())) return map(partial(MapTracer, self), out, outaxes) - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): if symbolic_zeros: msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " "Please open an issue at https://github.com/jax-ml/jax/issues !") @@ -537,7 +537,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): with core.set_current_trace(self): return fun.call_wrapped(*tracers) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " @@ -575,8 +575,8 @@ def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any, if src == dst: outval = val elif type(src) == type(dst) == int: - outval = batching.moveaxis(val, src, dst) - shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst) + outval = batching.moveaxis(val, src, dst) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 elif src is None and dst is not None: outval = batching.broadcast(val, axis_size, dst, None) shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()} @@ -1414,7 +1414,7 @@ def _pmap_partial_eval_custom_params_updater( def _pmap_partial_eval_custom_res_maker(params_known, aval): return core.unmapped_aval(params_known['axis_size'], 0, aval) -def _pmap_dce_rule(used_outputs, eqn): +def _pmap_dce_rule(used_outputs, eqn: core.JaxprEqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes if not 
any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None @@ -1872,7 +1872,7 @@ def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], gspmd_shardings = [ s if (isinstance(s, (UnspecifiedValue, AUTO)) or (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh))) - else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error + else to_gspmd_sharding(s, a.ndim) # type: ignore[missing-attribute] for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings @@ -2179,7 +2179,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment, if device_assignment is None: return shardings - out = [] + out: list[UnspecifiedValue | JSharding] = [] for s, a, mem_kind in zip(shardings, avals, out_mem_kinds): if isinstance(s, UnspecifiedValue) and isinstance(a, core.ShapedArray): if a.sharding.mesh.empty: @@ -2236,8 +2236,7 @@ def lower_sharding_computation( number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply the singleton UNSPECIFIED to all out_avals.""" - auto_spmd_lowering = check_if_any_auto( - it.chain.from_iterable([in_shardings, out_shardings])) + auto_spmd_lowering = check_if_any_auto(it.chain(in_shardings, out_shardings)) all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info) @@ -2583,6 +2582,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_list, memory_kind=mk) for mk in omk] + out_op_shardings: Sequence[xc.OpSharding] _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2671,7 +2671,7 @@ def _gspmd_to_single_device_sharding( def _get_out_sharding_from_orig_sharding( out_shardings, out_avals, orig_in_s, orig_aval): - out = [] + out: list[JSharding] = [] orig_handler = _orig_out_sharding_handlers[type(orig_in_s)] for o, out_aval in 
safe_zip(out_shardings, out_avals): if (isinstance(o, sharding_impls.GSPMDSharding) and @@ -2913,7 +2913,7 @@ def _maybe_get_and_check_out_shardings( dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval.shape, aval.dtype, xla_s) try: - new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig)) # pytype: disable=wrong-arg-types + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig)) # type: ignore[arg-type] except: new_out_shardings.append(xla_s) else: @@ -3054,7 +3054,7 @@ def from_hlo(name: str, mesh = None if auto_spmd_lowering: - for i in it.chain.from_iterable([in_shardings, out_shardings]): + for i in it.chain(in_shardings, out_shardings): # pyrefly: ignore[bad-argument-type] if isinstance(i, AUTO): mesh = i.mesh break @@ -3232,12 +3232,12 @@ def xla_extension_executable(self): return self.xla_executable def call(self, *args): - args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx] + args_after_dce = tuple(a for i, a in enumerate(args) if i in self._kept_var_idx) if (self._all_args_info is not None and self._all_args_info.debug_info.arg_names is not None): - arg_names_after_dce = [ + arg_names_after_dce = tuple( n for i, n in enumerate(self._all_args_info.debug_info.arg_names) - if i in self._kept_var_idx] + if i in self._kept_var_idx) else: arg_names_after_dce = ("",) * len(args_after_dce) @@ -3281,6 +3281,7 @@ def aot_cache_miss(*args, **kwargs): use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat) and not self._mut) else: + out_tree_dispatch = None use_fastpath = False if use_fastpath: diff --git a/jax/_src/interpreters/remat.py b/jax/_src/interpreters/remat.py index 72100287d9ec..49007a14ad02 100644 --- a/jax/_src/interpreters/remat.py +++ b/jax/_src/interpreters/remat.py @@ -59,8 +59,10 @@ def f_rem(rs, *args): return out_ft.unflatten(), Partial(f_rem, map(reduce_precision, rs)) class RematTracer(core.Tracer): + _trace: RematTrace # pyrefly: 
ignore[bad-override] + def __init__(self, trace, x, jaxpr_tracer): - self._trace = trace # type: ignore + self._trace = trace # pytype: disable=name-error self.val = x self.tracer = jaxpr_tracer @@ -83,7 +85,7 @@ def to_val_tracer_pair(self, x): else: raise NotImplementedError # TODO(mattjj) - def process_primitive(self, prim, tracers, params): + def process_primitive(self, prim, tracers, params, /): in_vals, in_vals2 = unzip2(map(self.to_val_tracer_pair, tracers)) if prim in rules: with core.set_current_trace(self.parent_trace): @@ -123,7 +125,7 @@ def _remat_jaxpr(jaxpr, policy): src = source_info_util.current() def new_arg(a): - return RematTracer(trace, fwd_trace.new_arg(a, src), rem_trace.new_arg(a, src)) # type: ignore # noqa: F821 + return RematTracer(trace, fwd_trace.new_arg(a, src), rem_trace.new_arg(a, src)) # noqa: F821 # pytype: disable=name-error tracers = map(new_arg, jaxpr.in_aval_qdds) with core.set_current_trace(trace, check_leaks=True): @@ -139,4 +141,4 @@ def new_arg(a): [*out_primals, *rem_consts], dbg.with_unknown_names(), src) fwd_trace.invalidate() fwd_jaxpr = core.ClosedJaxpr(fwd_jaxpr_, fwd_consts) - return fwd_jaxpr, rem_jaxpr, len(rem_consts) + return fwd_jaxpr, rem_jaxpr, len(rem_consts) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 117bed2eae2c..9c4c522c7b3b 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -299,7 +299,9 @@ def _approx_top_k_lowering(ctx, operand, *, k, init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1))) init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k) - init_val = mlir.ir_constant(init_val_array.reshape(())) + init_vals = mlir.flatten_ir_values( + [mlir.ir_constant(init_val_array.reshape(())) + ]) backend_config = { "reduction_dim" : mlir.i64_attr(reduction_dimension), @@ -313,16 +315,19 @@ def _approx_top_k_lowering(ctx, operand, *, k, if all(core.is_constant_shape(aval_out.shape) for aval_out in 
ctx.avals_out): result_shapes = None else: - result_shapes = [ + result_shapes = mlir.flatten_ir_values( mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) - for aval_out in ctx.avals_out] + for aval_out in ctx.avals_out + ) if core.is_constant_dim(k): backend_config["top_k"] = mlir.i64_attr(k) out = mlir.custom_call( "ApproxTopK", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg], + result_types=mlir.flatten_ir_types( + mlir.aval_to_ir_type(aval) for aval in ctx.avals_out + ), + operands=[operand, iota, *init_vals, init_arg], called_computations=[comparator.name.value], backend_config=backend_config, result_shapes=result_shapes) @@ -330,8 +335,10 @@ def _approx_top_k_lowering(ctx, operand, *, k, k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) out = mlir.custom_call( "stablehlo.dynamic_approx_top_k", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg, k_value], + result_types=mlir.flatten_ir_types( + mlir.aval_to_ir_type(aval) for aval in ctx.avals_out + ), + operands=[operand, iota, *init_vals, init_arg, k_value], called_computations=[comparator.name.value], backend_config=backend_config, result_shapes=result_shapes) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 6907210bbe36..15f29899f771 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -1200,7 +1200,7 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext, def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> Sequence[ir.Value]: v = mlir.ir_constant(np.int32(i)) - return [v] + return mlir.flatten_ir_values([v]) platform_rules: dict[str, mlir.LoweringRule] = {} default_rule = None diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1a0339d6115c..4b2364928fb8 100644 --- 
a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -214,7 +214,7 @@ def scan(f, init, xs, length=None): args_avals = args.map(core.get_aval) init_avals, xs_avals = args_avals.unpack() - from jax._src.hijax import HiType # type: ignore + from jax._src.hijax import HiType if any(isinstance(a, HiType) for a in xs_avals): if length is None: raise ValueError("must provide `length` to `scan`") @@ -2048,7 +2048,7 @@ def fun(*args): hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + *mlir.flatten_ir_values(new_z)]) # pyrefly: ignore[bad-argument-type] outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -2274,7 +2274,7 @@ def _while_to_lojax(*hi_args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(body_jaxpr)]) pe.apply_himut(body_jaxpr, [*hi_bconsts, *hi_carry], out_mut) return pe.raise_lo_outs(body_jaxpr.out_avals, lo_outs) -while_p.to_lojax = _while_to_lojax # type: ignore +while_p.to_lojax = _while_to_lojax def _insert_binders(jaxpr, n_after, vals): avals = _map(typeof, vals) @@ -2463,9 +2463,9 @@ def fori_loop(lower, upper, body_fun, init_val): "are statically known.") if lower_dtype != dtype: - lower = lax.convert_element_type(lower, dtype) # type: ignore + lower = lax.convert_element_type(lower, dtype) if upper_dtype != dtype: - upper = lax.convert_element_type(upper, dtype) # type: ignore + upper = lax.convert_element_type(upper, dtype) while_body_fun = _fori_body_fun(body_fun, body_fun_dbg) _, _, result = while_loop(_fori_cond_fun, while_body_fun, (lower, upper, init_val)) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 350b90f7adc7..9fc72b19aaec 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -791,7 +791,7 @@ 
def _conv_general_dilated_lower( return complex_conv(ctx, lhs, rhs) lhs_spec, rhs_spec, out_spec = dimension_numbers - dnums = hlo.ConvDimensionNumbers.get( # pyrefly: ignore[missing-attribute] + dnums = hlo.ConvDimensionNumbers.get( input_batch_dimension=lhs_spec[0], input_feature_dimension=lhs_spec[1], input_spatial_dimensions=list(lhs_spec[2:]), diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 56364a0fb8c6..dd169e5b6452 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -130,7 +130,6 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): # TODO: https://github.com/openxla/stablehlo/issues/1366 raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU") return [ - # pyrefly: ignore[missing-attribute] hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name), mlir.dense_int_array(fft_lengths)).result ] diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0cc0a8a4cdec..799e6e59a363 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1636,7 +1636,8 @@ def _convert_element_type( if new_dtype == old_dtype: if sharding is None: return operand - if isinstance(operand, core.Tracer) and operand.aval.sharding == sharding: + if (isinstance(operand, core.Tracer) and + operand.aval.sharding == sharding): # pyrefly: ignore[missing-attribute] return operand if sharding is not None or weak_type: raise NotImplementedError @@ -3584,7 +3585,7 @@ def full_like(x: ArrayLike | DuckTypedArray, return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] if sharding is None and shape is None and isinstance(x, core.Tracer): - sharding = x.aval.sharding + sharding = x.aval.sharding # pyrefly: ignore[missing-attribute] else: # If `x` has a sharding but no `_committed` attribute # (in case of ShapeDtypeStruct), default it to True. @@ -4618,9 +4619,9 @@ def _add_transpose(t, x, y): # api_test.py's CustomJVPTest.test_jaxpr_zeros. 
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y) x_aval = x.aval if ad.is_undefined_primal(x) else core.typeof(x) - x_aval = x_aval.to_cotangent_aval() + x_aval = x_aval.to_ct_aval() y_aval = y.aval if ad.is_undefined_primal(y) else core.typeof(y) - y_aval = y_aval.to_cotangent_aval() + y_aval = y_aval.to_ct_aval() if type(t) is ad_util.Zero: return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)] else: @@ -4720,8 +4721,8 @@ def _mul_unreduced_rule(out_sharding, x, y): lambda ydot, x, y: mul(x, ydot)) ad.defbilinear( mul_p, - lambda ct, x, y: _unbroadcast(x.aval.to_cotangent_aval(), mul(ct, y)), - lambda ct, x, y: _unbroadcast(y.aval.to_cotangent_aval(), mul(x, ct))) + lambda ct, x, y: _unbroadcast(x.aval.to_ct_aval(), mul(ct, y)), + lambda ct, x, y: _unbroadcast(y.aval.to_ct_aval(), mul(x, ct))) mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) def _div_transpose_rule(cotangent, x, y): @@ -4895,7 +4896,7 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type, assert ad.is_undefined_primal(operand) old_dtype = operand.aval.dtype old_weak_type = dtypes.is_weakly_typed(operand) - operand_ct_aval = operand.aval.to_cotangent_aval() + operand_ct_aval = operand.aval.to_ct_aval() if type(ct) is ad_util.Zero: return [ad_util.Zero(operand_ct_aval)] elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0: @@ -5447,7 +5448,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y out_axes = np.argsort(unsorted_axes) - xs = x.aval.to_cotangent_aval().sharding + xs = x.aval.to_ct_aval().sharding inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) ds = xs.update(spec=xs.spec.update(partitions=inverse_spec)) dot_general_out = dot_general(g, y, dims, precision=precision, @@ -6436,7 +6437,7 @@ def _broadcast_in_dim_typecheck_rule( def 
_broadcast_in_dim_transpose_rule(ct, operand, shape, broadcast_dimensions, sharding): - ct_aval = operand.aval.to_cotangent_aval() + ct_aval = operand.aval.to_ct_aval() if type(ct) is ad_util.Zero: return [ad_util.Zero(ct_aval)] if not isinstance(operand, ad.UndefinedPrimal): @@ -6738,7 +6739,7 @@ def _concatenate_transpose_rule(ct, *operands, dimension): operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape for o in operands] if type(ct) is ad_util.Zero: - return [ad_util.Zero(o.aval.to_cotangent_aval()) + return [ad_util.Zero(o.aval.to_ct_aval()) if ad.is_undefined_primal(o) else None for o in operands] else: return split(ct, tuple(shape[dimension] for shape in operand_shapes), @@ -6797,7 +6798,7 @@ def _split_weak_type_rule(operand, *, sizes, axis): def _split_transpose_rule(cotangents, operand, *, sizes, axis): assert ad.is_undefined_primal(operand) if all(type(t) is ad_util.Zero for t in cotangents): - return [ad_util.Zero(operand.aval.to_cotangent_aval())] + return [ad_util.Zero(operand.aval.to_ct_aval())] cotangents = [ct.instantiate() if type(ct) is ad_util.Zero else ct for ct in cotangents] return [concatenate(cotangents, dimension=axis)] @@ -7217,7 +7218,7 @@ def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): def _reshape_transpose_rule(ct, operand, *, new_sizes, dimensions, sharding): assert ad.is_undefined_primal(operand) - op_ct_aval = operand.aval.to_cotangent_aval() + op_ct_aval = operand.aval.to_ct_aval() if dimensions is None: return [reshape(ct, op_ct_aval.shape, out_sharding=op_ct_aval.sharding)] else: @@ -7692,7 +7693,7 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes, out_sharding): broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) result = broadcast_in_dim( cotangent, input_shape, broadcast_dimensions, - out_sharding=operand.aval.to_cotangent_aval().sharding) + out_sharding=operand.aval.to_ct_aval().sharding) assert result.shape == input_shape return [result] @@ 
-8228,9 +8229,12 @@ def _top_k_lower(ctx, operand, k, axis): out_values_aval, out_indices_aval, = ctx.avals_out results = mlir.custom_call( "stablehlo.dynamic_top_k", - result_types=[mlir.aval_to_ir_type(out_values_aval), - mlir.aval_to_ir_type(out_indices_aval)], - operands=[operand, k_value]).results + result_types=mlir.flatten_ir_types([ + mlir.aval_to_ir_type(out_values_aval), + mlir.aval_to_ir_type(out_indices_aval) + ]), + operands=[operand, k_value], + ).results # Move last dimension back into place if perm is not None: @@ -8419,10 +8423,13 @@ def _rng_bit_generator_lowering( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) out_key, out_vals = mlir.custom_call( "stablehlo.dynamic_rng_bit_generator", - result_types=[key.type, - mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))], - operands=[key, output_shape], - extra_attributes=dict(rng_algorithm=algorithm_attr)).results + result_types=mlir.flatten_ir_types([ + key.type, + mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype)) + ]), + operands=mlir.flatten_ir_values([key, output_shape]), + extra_attributes=dict(rng_algorithm=algorithm_attr), + ).results else: out_key, out_vals = hlo.RngBitGeneratorOp( key.type, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 16f5a46fdedc..479380e7dbb7 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2423,7 +2423,7 @@ def _triangular_solve_lowering( out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - hlo.TransposeAttr.get(transpose)) # pyrefly: ignore[missing-attribute] + hlo.TransposeAttr.get(transpose)) return [mlir.lower_with_sharding_in_types(ctx, out, out_aval)] @@ -2460,7 +2460,6 @@ def _triangular_solve_cpu_lower( return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - # pyrefly: ignore[missing-attribute] hlo.TransposeAttr.get(transpose))] triangular_solve_p = linalg_primitive( diff --git 
a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9d2cfe7fb94..b97a22b630af 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -50,8 +50,8 @@ unzip2) import numpy as np -unsafe_map, map = map, safe_map # type: ignore -unsafe_zip, zip = zip, safe_zip # type: ignore +unsafe_map, map = map, safe_map +unsafe_zip, zip = zip, safe_zip ### parallel traceables @@ -1477,7 +1477,7 @@ def _ragged_all_to_all_lowering( if not all(split_count == len(g) for g in replica_groups): raise ValueError('Replica groups must be equally sized') - ragged_all_to_all_attrs = { + ragged_all_to_all_attrs: dict[str, ir.Attribute] = { "replica_groups": _replica_groups_hlo(replica_groups) } is_spmd = isinstance( diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 6686b0f379ee..fe6440487610 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -173,7 +173,7 @@ def dynamic_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) - sizes = core.canonicalize_shape(slice_sizes) # type: ignore + sizes = core.canonicalize_shape(slice_sizes) operand, *start_indices = core.standard_insert_pvary(operand, *start_indices) return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(sizes)) @@ -1702,8 +1702,8 @@ def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices): assert all(not ad.is_undefined_primal(x) for x in start_indices) update_shape = (update.aval.shape if ad.is_undefined_primal(update) else update.shape) - operand_ct_aval = operand.aval.to_cotangent_aval() - update_ct_aval = update.aval.to_cotangent_aval() + operand_ct_aval = operand.aval.to_ct_aval() + update_ct_aval = update.aval.to_ct_aval() if type(t) is ad_util.Zero: operand_t = (ad_util.Zero(operand_ct_aval) if ad.is_undefined_primal(operand) else None) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 931684ef8523..43a714f46a2d 100644 --- 
a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -487,7 +487,7 @@ def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: if jaxpr.effects: raise NotImplementedError('Cannot lower effectful `reduce_window`.') out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack, - mlir.TokenSet(), consts, *reducer.arguments, # type: ignore[misc] + mlir.TokenSet(), consts, *reducer.arguments, dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering, outer_traceback=ctx.traceback) return mlir.flatten_ir_values(out_nodes) @@ -997,7 +997,7 @@ def snd(t, t_aval): def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: x: ir.Value y: ir.Value - x, y = reducer.arguments # type: ignore + x, y = reducer.arguments assert select_prim is lax.ge_p or select_prim is lax.le_p cmp_op = "GE" if select_prim is lax.ge_p else "LE" out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y) diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index ab035fecc31a..92b6d5e5d39f 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -241,8 +241,8 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, def dot_general(lhs, rhs, dimension_numbers): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers new_id = itertools.count() - lhs_axis_ids = [next(new_id) for _ in lhs.shape] - rhs_axis_ids = [next(new_id) for _ in rhs.shape] + lhs_axis_ids: list[int | None] = [next(new_id) for _ in lhs.shape] + rhs_axis_ids: list[int | None] = [next(new_id) for _ in rhs.shape] lhs_out_axis_ids = lhs_axis_ids[:] rhs_out_axis_ids = rhs_axis_ids[:] @@ -267,8 +267,9 @@ def dot_general(lhs, rhs, dimension_numbers): batch_ids + lhs_out_axis_ids + rhs_out_axis_ids) assert lhs.dtype == rhs.dtype dtype = np.float32 if lhs.dtype == dtypes.bfloat16 else None - out = np.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids, - dtype=dtype) + out = np.einsum( # pyrefly: 
ignore[no-matching-overload] + lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids, dtype=dtype + ) return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out def ragged_dot( diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 2d3a23941b76..8e9bf8eb4f42 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -49,7 +49,7 @@ def __init__(self, major_to_minor: tuple[int, ...], def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): xla_layout = pjrt_layout._xla_layout() return Layout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types - xla_layout.tiling(), # type: ignore[arg-type] + xla_layout.tiling(), # pyrefly: ignore[bad-argument-type] xla_layout.element_size_in_bits()) def __repr__(self): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 7a09dc92a17b..c08ffb10fcf9 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -22,7 +22,7 @@ filelock: Any | None = None try: - import filelock # type: ignore[no-redef] + import filelock # pyrefly: ignore[missing-import] except ImportError: pass diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index f0210d5d9b82..8ebf84d625f7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -251,6 +251,7 @@ class Mesh(BaseMesh, contextlib.ContextDecorator): devices: np.ndarray axis_names: tuple[MeshAxisName, ...] 
+ _size: int def __new__(cls, devices: np.ndarray | Sequence[xc.Device], axis_names: str | Sequence[MeshAxisName], @@ -354,7 +355,7 @@ def shape(self): for name, size in safe_zip(self.axis_names, self.devices.shape)) @functools.cached_property - def shape_tuple(self): + def shape_tuple(self): # pyrefly: ignore[bad-override] return tuple( (name, size) for name, size in safe_zip(self.axis_names, self.devices.shape)) @@ -538,7 +539,7 @@ def shape(self): return collections.OrderedDict(self.shape_tuple) @functools.cached_property - def shape_tuple(self): + def shape_tuple(self): # pyrefly: ignore[bad-override] return tuple( (name, size) for name, size in safe_zip(self.axis_names, self.axis_sizes)) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5581e6fd5f2e..d1215c561596 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -18,7 +18,7 @@ import collections import dataclasses import functools -from typing import Any, Union +from typing import Any, Union, overload from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc @@ -204,7 +204,7 @@ def is_fully_addressable(self) -> bool: raise ValueError('is_fully_addressable is not implemented for ' '`jax.sharding.AbstractMesh`.') # return False if addressable_device_list is empty. - return self._internal_device_list.is_fully_addressable # type: ignore + return self._internal_device_list.is_fully_addressable @property def _is_concrete(self) -> bool: @@ -228,7 +228,7 @@ def is_fully_replicated(self) -> bool: array_mapping = get_array_mapping(self.spec) mesh_shape = self.mesh.shape num_partitions = 1 - for name in array_mapping: # type: ignore + for name in array_mapping: num_partitions *= mesh_shape[name] return num_partitions == 1 @@ -286,6 +286,18 @@ def flatten_spec(spec): return out +@overload +def get_array_mapping(axis_resources: PartitionSpec) -> ArrayMapping: + ... 
+ +@overload +def get_array_mapping(axis_resources: AUTO) -> AUTO: + ... + +@overload +def get_array_mapping(axis_resources: UnspecifiedValue) -> UnspecifiedValue: + ... + def get_array_mapping( axis_resources: PartitionSpec | AUTO | UnspecifiedValue ) -> ArrayMappingOrAutoOrUnspecified: @@ -404,7 +416,7 @@ def named_sharding_to_xla_hlo_sharding( replicated_mesh_axes = [] for i, (axis_name, axis_val) in enumerate(mesh_shape.items()): - if axis_name not in array_mapping: # type: ignore + if axis_name not in array_mapping: replicated_mesh_axes.append((i, axis_val)) if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes: @@ -412,7 +424,7 @@ def named_sharding_to_xla_hlo_sharding( mesh_permutation = [] new_mesh_shape = [1] * num_dimensions - for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): # type: ignore + for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): new_mesh_shape[pos] *= mesh_shape[name] mesh_permutation.append(mesh_axis_pos[name]) @@ -468,11 +480,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): reverse_map[index].append(axis) if index > max_index: max_index = index - partitions = [] + partitions: list[MeshAxisName | None] = [] for i in range(max_index + 1): axis = reverse_map[i] if axis: - partitions.append(axis[0] if len(axis) == 1 else tuple(axis)) + partitions.append(axis[0] if len(axis) == 1 else tuple(axis)) # pytype: disable=container-type-mismatch else: partitions.append(None) return PartitionSpec(*partitions) diff --git a/jax/_src/numpy/array_constructors.py b/jax/_src/numpy/array_constructors.py index 43c6f647249f..e92054a91f89 100644 --- a/jax/_src/numpy/array_constructors.py +++ b/jax/_src/numpy/array_constructors.py @@ -191,7 +191,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, weak_type = dtype is None and dtypes.is_weakly_typed(object) if device is None and out_sharding is None and isinstance(object, core.Tracer): - sharding = 
object.aval.sharding + sharding = object.aval.sharding # pyrefly: ignore[missing-attribute] sharding = None if sharding.mesh.empty else sharding else: sharding = util.choose_device_or_out_sharding(device, out_sharding, "jnp.array") diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 8f10f236bb7d..a08facb1acce 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -13,7 +13,6 @@ # limitations under the License. # pytype: skip-file -# mypy: disable-error-code=has-type """Define methods which are dynamically added to JAX's Arrays and Tracers. This is done dynamically in order to avoid circular imports. @@ -26,7 +25,7 @@ import abc from functools import wraps import math -from typing import Any +from typing import Any, cast from collections.abc import Callable, Sequence import numpy as np @@ -589,6 +588,11 @@ def _notimplemented_flat(self): raise NotImplementedError("JAX Arrays do not implement the arr.flat property: " "consider arr.flatten() instead.") +# TODO(jakevdp): make _accepted_binop_types match the ArrayLike union. Currently +# ArrayLike includes np.number, while here we are more permissive and include +# np.generic: this is required because as of v0.5.X, ml_dtypes types are subclasses +# of np.generic rather than of np.number. Making these match will allow removal of +# cast() calls in the operator definitions below. _accepted_binop_types = ( int, float, @@ -600,21 +604,345 @@ def _notimplemented_flat(self): ) _rejected_binop_types = (list, tuple, set, dict) -def _defer_to_unrecognized_arg(opchar, binary_op, swap=False): - # Ensure that other array types have the chance to override arithmetic. 
- def deferring_binary_op(self, other): - if hasattr(other, '__jax_array__'): - other = other.__jax_array__() - args = (other, self) if swap else (self, other) - if isinstance(other, _accepted_binop_types): - return binary_op(*args) - # Note: don't use isinstance here, because we don't want to raise for - # subclasses, e.g. NamedTuple objects that may override operators. - if type(other) in _rejected_binop_types: - raise TypeError(f"unsupported operand type(s) for {opchar}: " - f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}") - return NotImplemented - return deferring_binary_op +def _operator_eq(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.equal(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for ==: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_ne(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.not_equal(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for !=: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_lt(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.less(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for <: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_le(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.less_equal(self, cast(ArrayLike, other)) + if type(other) in 
_rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for <=: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_gt(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.greater(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for >: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_ge(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.greater_equal(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for >=: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_add(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.add(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for +: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_radd(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.add(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for +: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_sub(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.subtract(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported 
operand type(s) for -: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rsub(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.subtract(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for -: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_mul(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.multiply(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for *: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rmul(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.multiply(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for *: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_truediv(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.true_divide(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for /: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rtruediv(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.true_divide(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for /: " + 
f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_floordiv(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.floor_divide(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for //: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rfloordiv(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.floor_divide(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for //: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_divmod(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.divmod(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for divmod: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rdivmod(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.divmod(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for divmod: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_mod(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.mod(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for %: " + 
f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rmod(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.mod(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for %: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_pow(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.power(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for **: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rpow(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.power(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for **: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_matmul(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return tensor_contractions.matmul(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for @: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rmatmul(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return tensor_contractions.matmul(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for @: " + f"{type(other).__name__!r} 
and {type(self).__name__!r}") + return NotImplemented + +def _operator_and(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_and(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for &: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rand(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_and(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for &: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_or(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_or(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for |: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_ror(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_or(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for |: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_xor(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_xor(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for ^: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + 
return NotImplemented + +def _operator_rxor(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.bitwise_xor(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for ^: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_lshift(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.left_shift(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for <<: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rshift(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.right_shift(self, cast(ArrayLike, other)) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for >>: " + f"{type(self).__name__!r} and {type(other).__name__!r}") + return NotImplemented + +def _operator_rlshift(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.left_shift(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for <<: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return NotImplemented + +def _operator_rrshift(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + if isinstance(other, _accepted_binop_types): + return ufuncs.right_shift(cast(ArrayLike, other), self) + if type(other) in _rejected_binop_types: + raise TypeError(f"unsupported operand type(s) for >>: " + f"{type(other).__name__!r} and {type(self).__name__!r}") + return 
NotImplemented def _unimplemented_setitem(self, i, x): msg = ("JAX arrays are immutable and do not support in-place item assignment." @@ -1022,44 +1350,44 @@ def max(self, values: ArrayLike, *, "setitem": _unimplemented_setitem, "copy": _copy, "deepcopy": _deepcopy, - "neg": lambda self: ufuncs.negative(self), - "pos": lambda self: ufuncs.positive(self), - "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), - "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), - "lt": _defer_to_unrecognized_arg("<", ufuncs.less), - "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), - "gt": _defer_to_unrecognized_arg(">", ufuncs.greater), - "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), - "abs": lambda self: ufuncs.abs(self), - "add": _defer_to_unrecognized_arg("+", ufuncs.add), - "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), - "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), - "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), - "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), - "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), - "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), - "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), - "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), - "rfloordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide, swap=True), - "divmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod), - "rdivmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod, swap=True), - "mod": _defer_to_unrecognized_arg("%", ufuncs.mod), - "rmod": _defer_to_unrecognized_arg("%", ufuncs.mod, swap=True), - "pow": _defer_to_unrecognized_arg("**", ufuncs.power), - "rpow": _defer_to_unrecognized_arg("**", ufuncs.power, swap=True), - "matmul": _defer_to_unrecognized_arg("@", tensor_contractions.matmul), - "rmatmul": _defer_to_unrecognized_arg("@", tensor_contractions.matmul, swap=True), - "and": 
_defer_to_unrecognized_arg("&", ufuncs.bitwise_and), - "rand": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and, swap=True), - "or": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or), - "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), - "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), - "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), - "invert": lambda self: ufuncs.bitwise_not(self), - "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), - "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), - "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), - "rrshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift, swap=True), + "neg": ufuncs.negative._func, + "pos": ufuncs.positive, + "abs": ufuncs.abs, + "invert": ufuncs.invert, + "eq": _operator_eq, + "ne": _operator_ne, + "lt": _operator_lt, + "le": _operator_le, + "gt": _operator_gt, + "ge": _operator_ge, + "add": _operator_add, + "radd": _operator_radd, + "sub": _operator_sub, + "rsub": _operator_rsub, + "mul": _operator_mul, + "rmul": _operator_rmul, + "truediv": _operator_truediv, + "rtruediv": _operator_rtruediv, + "floordiv": _operator_floordiv, + "rfloordiv": _operator_rfloordiv, + "divmod": _operator_divmod, + "rdivmod": _operator_rdivmod, + "mod": _operator_mod, + "rmod": _operator_rmod, + "pow": _operator_pow, + "rpow": _operator_rpow, + "matmul": _operator_matmul, + "rmatmul": _operator_rmatmul, + "and": _operator_and, + "rand": _operator_rand, + "or": _operator_or, + "ror": _operator_ror, + "xor": _operator_xor, + "rxor": _operator_rxor, + "lshift": _operator_lshift, + "rshift": _operator_rshift, + "rlshift": _operator_rlshift, + "rrshift": _operator_rrshift, "round": _operator_round, } diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index b8f72081df62..c3a4c244e72f 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -299,7 +299,7 @@ def einsum( for d in np.shape(op) 
if not core.is_constant_dim(d) } if not non_constant_dim_types: - contract_path = opt_einsum.contract_path + contract_path: Any = opt_einsum.contract_path else: ty = next(iter(non_constant_dim_types)) contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler) @@ -339,7 +339,7 @@ def einsum( # Enable other modules to override einsum_contact_path. # Indexed by the type of the non constant dimension -_poly_einsum_handlers = {} # type: ignore +_poly_einsum_handlers = {} def _default_poly_einsum_handler(*operands, **kwargs): dummy = collections.namedtuple('dummy', ['shape', 'dtype']) @@ -419,8 +419,11 @@ def einsum_path( .. _opt_einsum: https://github.com/dgasmith/opt_einsum """ if isinstance(optimize, bool): - optimize = 'optimal' if optimize else Unoptimized() - return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) + optimize2: Any = 'optimal' if optimize else Unoptimized() + else: + optimize2 = optimize + # pyrefly: ignore[no-matching-overload] + return opt_einsum.contract_path(subscripts, *operands, optimize=optimize2) def _removechars(s, chars): return s.translate(str.maketrans(dict.fromkeys(chars))) @@ -569,7 +572,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): out_sharding = (_get_inverse_sharding(out_sharding, names, result_names) if out_sharding is not None and names != result_names else out_sharding) - dot_out_sharding = ({} if out_sharding is None else # type: ignore + dot_out_sharding = ({} if out_sharding is None else {'out_sharding': out_sharding}) operand = _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type=preferred_element_type, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1b06de17eb58..bf7e55d5ab98 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3404,7 +3404,7 @@ def clip( if min is not None: arr = ufuncs.maximum(min, arr) if max is not None: - arr = ufuncs.minimum(max, arr) # type: ignore + arr = 
ufuncs.minimum(max, arr) return asarray(arr) @@ -6422,13 +6422,13 @@ def _auto_repeat(fun, a, repeats, axis, total_repeat_length, out_sharding): return auto_axes(partial(fun, repeats, axis=axis, total_repeat_length=total_repeat_length), out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(a) else: return auto_axes( partial(fun, axis=axis, total_repeat_length=total_repeat_length), out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(repeats, a) def _repeat(repeats, arr, *, axis: int, @@ -6454,7 +6454,7 @@ def _repeat(repeats, arr, *, axis: int, axis = _canonicalize_axis(axis, len(input_shape)) aux_axis = axis + 1 aux_shape: list[DimSize] = list(input_shape) - aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) # type: ignore + aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) arr = lax.broadcast_in_dim( arr, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis]) result_shape: list[DimSize] = list(input_shape) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index eb522aa56125..3bd966e2ca23 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -403,7 +403,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: z: Array | None = None result: Array | None = None while n > 0: - z = arr if z is None else (z @ z) # type: ignore[operator] + z = arr if z is None else (z @ z) n, bit = divmod(n, 2) if bit: result = z if result is None else (result @ z) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index da054665b3b0..ca9857b62099 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -446,7 +446,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: del p, x shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) y = 
lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) - y, _ = control_flow.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) # type: ignore[misc] + y, _ = control_flow.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 53a78cf6c8ac..4ce903e3f0e3 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -91,7 +91,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes.lattice_result_type(*args) - to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type] + to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] @@ -408,7 +408,7 @@ def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: if hasattr(a, "__jax_array__"): a = a.__jax_array__() # NumPy dispatches to a.shape if available. - return np.shape(a) # type: ignore[arg-type] + return np.shape(a) @export diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index f120d386b5fd..ecd885c4df00 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -83,7 +83,7 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...], normalize_indices=normalize_indices) if out_sharding is not None: return auto_axes(internal_scatter, out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(x, y, dynamic_idx) return internal_scatter(x, y, tuple(dynamic_idx)) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8cd166bdbed7..4bdd1fb979ba 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1171,9 +1171,9 @@ def get_grid_mapping( debug: bool = False, ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): - dim_check : Any = jax_core.is_dim + dim_check: Any = 
jax_core.is_dim else: - dim_check : Any = jax_core.is_constant_dim # type: ignore[no-redef] + dim_check: Any = jax_core.is_constant_dim assert all(i is None or dim_check(i) for i in grid_spec.grid) grid_mapping_grid = tuple( dynamic_grid_dim if ( @@ -1235,7 +1235,7 @@ def get_grid_mapping( _convert_block_spec_to_block_mapping, index_map_avals=index_map_avals, index_map_tree=index_map_tree, - grid=grid_mapping_grid, # type: ignore[arg-type] + grid=grid_mapping_grid, vmapped_dims=(), debug=debug, ), @@ -1258,7 +1258,7 @@ def get_grid_mapping( _convert_block_spec_to_block_mapping, index_map_avals=index_map_avals, index_map_tree=index_map_tree, - grid=grid_mapping_grid, # type: ignore[arg-type] + grid=grid_mapping_grid, vmapped_dims=(), debug=debug, ), @@ -1295,9 +1295,9 @@ def get_grid_mapping( def unzip_dynamic_grid_bounds( grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]: if dynamic_shapes_export_enabled(): - new_grid : Any = grid_spec.grid + new_grid: Any = grid_spec.grid else: - new_grid : Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) # type: ignore[no-redef] + new_grid: Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int)) # We can't use dataclasses.replace, because our fields are incompatible # with __init__'s signature. diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index d19b3cf1e951..7241c5e80046 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# mypy: ignore-errors # pyrefly: ignore-errors # TODO(sharadmv): Enable type checking. 
diff --git a/jax/_src/pallas/fuser/custom_fusion_lib.py b/jax/_src/pallas/fuser/custom_fusion_lib.py index 9b705d86260d..8e1cc3b094c6 100644 --- a/jax/_src/pallas/fuser/custom_fusion_lib.py +++ b/jax/_src/pallas/fuser/custom_fusion_lib.py @@ -226,7 +226,7 @@ def _custom_fusion_mosaic_lowering_rule( lowering_context, pallas_jaxpr, *pallas_consts, *args) -@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p) def _custom_fusion_pull_block_spec_rule( ctx : block_spec_lib.PullRuleContext, out_block_transforms : tuple[block_spec_lib.BlockIndexTransform, ...], @@ -238,7 +238,7 @@ def _custom_fusion_pull_block_spec_rule( return pull_block_spec_rule(out_block_transforms) -@block_spec_lib.register_push_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_push_block_spec_rule(custom_fusion_p) def _custom_fusion_push_block_spec_rule( ctx : block_spec_lib.PushRuleContext, *block_specs : pallas_core.BlockSpec, @@ -250,7 +250,7 @@ def _custom_fusion_push_block_spec_rule( return push_block_spec_rule(block_specs) -@block_spec_lib.register_usage_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_usage_rule(custom_fusion_p) def _custom_fusion_usage_rule( ctx : block_spec_lib.UsageRuleContext, used_out: Sequence[set[block_spec_lib.Usage]], diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 8a1684bb0466..f4d2ac6a22ca 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -43,7 +43,6 @@ from jax._src.util import foreach # TODO(sharadmv): Enable type checking. 
-# mypy: ignore-errors map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 496e43e6fc0d..3b6af7f7214d 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -256,9 +256,9 @@ def kernel(in_ref, out_ref): name=name, metadata=metadata) if isinstance(body, api.NotSpecified): - return lambda fun: _make_kernel(fun, **kwds) # type: ignore[arg-type] + return lambda fun: _make_kernel(fun, **kwds) else: - return _make_kernel(body, **kwds) # type: ignore[arg-type] + return _make_kernel(body, **kwds) def with_scoped( diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 643879bc9a44..7180468c69f0 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -97,7 +97,7 @@ def _dynamic_slice( output = slicing.dynamic_slice(value, start_idx, slice_sizes=block_shape) squeeze_dims = tuple(np.arange(len(is_squeeze))[np.array(is_squeeze, dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) # type: ignore[arg-type] + return lax.squeeze(output, squeeze_dims) def _dynamic_update_slice(start_idx, block_shape, value, update, is_squeeze): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 603d1a8f6e15..e0ccf6c65355 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Contains TPU-specific Pallas abstractions.""" + from __future__ import annotations import collections @@ -54,6 +55,7 @@ class GridDimensionSemantics(enum.Enum): SUBCORE_PARALLEL = "subcore_parallel" ARBITRARY = "arbitrary" + PARALLEL = GridDimensionSemantics.PARALLEL CORE_PARALLEL = GridDimensionSemantics.CORE_PARALLEL SUBCORE_PARALLEL = GridDimensionSemantics.SUBCORE_PARALLEL @@ -212,12 +214,15 @@ def __getattr__(self, name): return super().__getattr__(name) # type: ignore -class dma_semaphore(pallas_core.semaphore_dtype): pass +class dma_semaphore(pallas_core.semaphore_dtype): + pass + class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" + class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -231,8 +236,9 @@ def __call__(self, shape: tuple[int, ...]): dtype = pallas_core.BarrierSemaphore() else: dtype = pallas_core.Semaphore() - return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype), - MemorySpace.SEMAPHORE) + return pallas_core.MemoryRef( + jax_core.ShapedArray(shape, dtype), MemorySpace.SEMAPHORE + ) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: return self(()).get_array_aval() @@ -240,6 +246,7 @@ def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: def get_ref_aval(self) -> state.AbstractRef: return self(()).get_ref_aval() + @dataclasses.dataclass(frozen=True) class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType @@ -255,15 +262,16 @@ def __init__( grid: pallas_core.Grid = (), in_specs: pallas_core.BlockSpecTree = no_block_spec, out_specs: pallas_core.BlockSpecTree = no_block_spec, - scratch_shapes: pallas_core.ScratchShapeTree = () + scratch_shapes: pallas_core.ScratchShapeTree = (), ): super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) def _make_scalar_ref_aval(self, aval): - return state.AbstractRef(jax_core.ShapedArray(aval.shape, 
aval.dtype), - MemorySpace.SMEM) + return state.AbstractRef( + jax_core.ShapedArray(aval.shape, aval.dtype), MemorySpace.SMEM + ) @dataclasses.dataclass(frozen=True) @@ -274,6 +282,7 @@ class TensorCore: @dataclasses.dataclass(frozen=True) class TensorCoreMesh: """A mesh of TensorCores.""" + devices: np.ndarray axis_names: Sequence[str] @@ -315,7 +324,7 @@ def create_tensorcore_mesh( num_cores: int | None = None, ) -> TensorCoreMesh: if devices is not None and num_cores is not None: - raise ValueError('cannot specify both devices and num_cores') + raise ValueError("cannot specify both devices and num_cores") if num_cores is None: if devices is None: abstract_device = jax.sharding.get_abstract_mesh().abstract_device @@ -330,6 +339,120 @@ def create_tensorcore_mesh( ) +def pass_scalars_as_refs( + jaxpr: jax_core.Jaxpr, + args: Sequence[Any], + in_avals: Sequence[jax_core.AbstractValue], + out_avals: Sequence[jax_core.AbstractValue], + mesh, + copy_to_smem: bool = False, +) -> tuple[ + jax_core.Jaxpr, + tuple[Any, ...], + tuple[jax_core.AbstractValue, ...], + tuple[jax_core.AbstractValue, ...], + tuple[bool, ...], +]: + """Rewrites a jaxpr to pass scalars as refs instead of values.""" + def allowed_aval(aval): + if isinstance(aval, state.AbstractRef): + return True + if isinstance(aval, jax_core.ShapedArray): + # Only scalars are allowed. 
+ return not aval.shape + return False + + assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) + + is_scalar_const = [ + isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape + for v in jaxpr.constvars + ] + if not any(is_scalar_const): + return ( + jaxpr, + tuple(in_avals), + tuple(out_avals), + tuple(args), + tuple(is_scalar_const), + ) + non_scalar_const_avals, scalar_const_avals = util.partition_list( + is_scalar_const, + [v.aval for v in jaxpr.constvars], + ) + non_scalar_consts, scalar_consts = util.partition_list( + is_scalar_const, args + ) + if copy_to_smem: + smem_alloc = [ + state.AbstractRef( + jax_core.ShapedArray((1,), aval.dtype), # pyrefly: ignore[missing-attribute] + memory_space=MemorySpace.SMEM, + ) + for aval in scalar_const_avals + ] + else: + smem_alloc = [] + + # Rewrite body jaxpr to take in scalar values as Refs. + def new_body(*args): + scalar_const_refs, non_scalar_const_refs, args = util.split_list( + args, [len(scalar_consts), len(non_scalar_consts)] + ) + if copy_to_smem: + smem, args = util.split_list(args, [len(smem_alloc)]) + assert len(smem) == len(scalar_const_refs) + from jax._src.pallas.mosaic.helpers import sync_copy + + sync_copy(scalar_const_refs, smem) + else: + smem = scalar_const_refs + scalar_const_values = [s[0] for s in smem] + new_consts = util.merge_lists( + is_scalar_const, non_scalar_const_refs, scalar_const_values + ) + return jax_core.eval_jaxpr(jaxpr, new_consts, *args) + + # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. 
+ scalar_const_trace_avals = [ + state.AbstractRef( + jax_core.ShapedArray((1,), aval.dtype), # pyrefly: ignore[missing-attribute] + memory_space=MemorySpace.HBM if copy_to_smem else MemorySpace.SMEM, + ) + for aval in scalar_const_avals + ] + new_trace_avals = [ + *scalar_const_trace_avals, + *non_scalar_const_avals, + *smem_alloc, + *[v.aval for v in jaxpr.invars], + ] + with ( + pallas_core.tracing_grid_env( + tuple(mesh.shape.values()), mapped_dims=() + ), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init( + new_body, debug_info=jaxpr.debug_info.with_unknown_names() + ), + new_trace_avals, + ) + jaxpr = new_jaxpr.replace( + constvars=new_jaxpr.invars[: len(jaxpr.constvars)], + invars=new_jaxpr.invars[len(jaxpr.constvars) :], + ) + args = [ + *[a[None] for a in scalar_consts], + *non_scalar_consts, + ] + in_avals, out_avals, _ = util.split_list( + new_trace_avals, [len(in_avals), len(out_avals)] + ) + return jaxpr, tuple(in_avals), tuple(out_avals), tuple(args), tuple(is_scalar_const) + + def _tensorcore_mesh_discharge_rule( in_avals, out_avals, @@ -345,17 +468,13 @@ def _tensorcore_mesh_discharge_rule( ): assert isinstance(mesh, TensorCoreMesh) if compiler_params and not isinstance(compiler_params, CompilerParams): - raise ValueError( - "compiler_params must be a pltpu.CompilerParams" - ) + raise ValueError("compiler_params must be a pltpu.CompilerParams") if not compiler_params: compiler_params = CompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") if compiler_params.dimension_semantics is not None: - raise ValueError( - "dimension_semantics must be None for TensorCoreMesh" - ) + raise ValueError("dimension_semantics must be None for TensorCoreMesh") num_cores = len(mesh.devices) if num_cores > 1: # Since each core will have its own VMEM, we currently disallow VMEM inputs @@ -369,54 +488,10 @@ def _tensorcore_mesh_discharge_rule( "TensorCoreMesh does not 
support VMEM inputs/outputs when there are" " >1 cores. Use HBM or ANY instead." ) - def allowed_aval(aval): - if isinstance(aval, state.AbstractRef): - return True - if isinstance(aval, jax_core.ShapedArray): - # Only scalars are allowed. - return not aval.shape - return False - assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) - - is_scalar_const = [ - isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape - for v in jaxpr.constvars - ] - if any(is_scalar_const): - # Rewrite body jaxpr to take in scalar values as Refs. - def new_body(*args): - args = [ - a[0] if is_scalar else a - for a, is_scalar in zip(args, is_scalar_const) - ] - return jax_core.eval_jaxpr(jaxpr, args) - # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. - new_trace_avals = [ - state.AbstractRef( # pylint: disable=g-long-ternary - jax_core.ShapedArray((1,), v.aval.dtype), - memory_space=MemorySpace.SMEM, - ) - if is_scalar - else v.aval - for v, is_scalar in zip(jaxpr.constvars, is_scalar_const) - ] - with ( - pallas_core.tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), - jax_core.extend_axis_env_nd(mesh.shape.items()), - ): - new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - new_body, debug_info=jaxpr.debug_info.with_unknown_names() - ), - new_trace_avals, - ) - jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars) - args = tuple( - a[None] if is_scalar else a - for a, is_scalar in zip(args, is_scalar_const) - ) - in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)]) - return pallas_core.default_mesh_discharge_rule( + jaxpr, in_avals, out_avals, args, is_scalar_const = pass_scalars_as_refs( + jaxpr, args, in_avals, out_avals, mesh + ) + refs_out, out = pallas_core.default_mesh_discharge_rule( in_avals, out_avals, *args, @@ -429,6 +504,11 @@ def new_body(*args): name=name, metadata=metadata, ) + refs_out = [ + a if not is_scalar else None + for is_scalar, a in zip(is_scalar_const, 
refs_out) + ] + return refs_out, out pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( @@ -452,6 +532,7 @@ def get_device_kind() -> str: return abstract_device.device_kind return jex_backend.get_default_device().device_kind + def get_num_device_cores() -> int: if abstract_device := jax.sharding.get_abstract_mesh().abstract_device: return abstract_device.num_cores diff --git a/jax/_src/pallas/mosaic/interpret/BUILD b/jax/_src/pallas/mosaic/interpret/BUILD index 2a86f2258032..9fcde71ee532 100644 --- a/jax/_src/pallas/mosaic/interpret/BUILD +++ b/jax/_src/pallas/mosaic/interpret/BUILD @@ -44,7 +44,9 @@ py_library( "//jax/_src:frozen_dict", "//jax/_src:lax", "//jax/_src:mlir", + "//jax/_src:partial_eval", "//jax/_src:source_info_util", + "//jax/_src:tree_util", "//jax/_src:typing", "//jax/_src:util", "//jax/_src/pallas", @@ -64,13 +66,17 @@ pytype_strict_library( srcs = ["shared_memory.py"], deps = [ ":race_detection_state", + ":utils", ":vector_clock", "//jax", "//jax/_src:source_info_util", "//jax/_src:typing", "//jax/_src/pallas", "//jax/_src/pallas/mosaic:core", - ] + py_deps("numpy"), + ] + py_deps([ + "absl/logging", + "numpy", + ]), ) pytype_strict_library( @@ -97,6 +103,7 @@ pytype_strict_library( deps = [ "//jax", "//jax/_src:core", + "//jax/_src:source_info_util", "//jax/_src:util", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index f0ceaa306ee2..ea15251c79db 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -140,7 +140,7 @@ def force_tpu_interpret_mode(params: InterpretParams = InterpretParams()): config.pallas_tpu_interpret_mode_context_manager.set_local(prev) def set_tpu_interpret_mode(params: InterpretParams = InterpretParams()): - config.pallas_tpu_interpret_mode_context_manager.set_global(params) # type: ignore[arg-type] + 
config.pallas_tpu_interpret_mode_context_manager.set_global(params) # TODO(jburnim): Do we want to support multiple instances of SharedMemory? @@ -212,6 +212,7 @@ def _initialize_shared_memory( clean_up_barrier=threading.Barrier( num_devices, action=_clear_shared_memory ), + logging_mode=interpret_params.logging_mode, ) assert _shared_memory.num_cores == num_cores @@ -303,6 +304,7 @@ def _allocate_buffer( local_core_id: Array | None, memory_space: Array, val: Array, + source_info: source_info_util.SourceInfo | None = None, ): """Allocates a memory buffer on the device with id `device_id` and core with id `local_core_id`. @@ -316,6 +318,7 @@ def _allocate_buffer( buffer in. If the corresponding memory space is "any" (i.e. HBM), at most one buffer will be allocated and it will belong to (local) core id 0. val: Array of values to initialize the allocated buffer with. + source_info: Information about the source code location of the allocation. Returns: Integer id for the allocated buffer. @@ -356,7 +359,14 @@ def _allocate_buffer( val = val.copy() shared_memory.allocate_buffer( - key, ref_count=ref_count, value=np.array(val) + key, + ref_count=ref_count, + value=np.array(val), + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=lci, + source_info=source_info, + ), ) local_core_id_to_buffer_id[lci] = buffer_id @@ -375,7 +385,9 @@ def _local_core_id_or_zero_if_hbm(local_core_id: int, memory_space: str) -> int: return local_core_id -def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): +def _deallocate_buffer( + device_id, local_core_id, memory_space, buffer_id, source_info=None +): device_id = int(device_id) local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] @@ -385,7 +397,14 @@ def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): shared_memory = _get_shared_memory() key = (memory_space, buffer_id, device_id, local_core_id) - 
shared_memory.deallocate_buffer(key) + shared_memory.deallocate_buffer( + key, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), + ) def _allocate_semaphores( @@ -532,7 +551,14 @@ def get( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) read_range = interpret_utils.to_range(transforms) ret, (shape, dtype), clock_ = shared_memory.get_buffer_content( - key, read_range, global_core_id + key, + read_range, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) clock = clock if clock is not None else clock_ @@ -657,7 +683,15 @@ def store( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) write_range = interpret_utils.to_range(transforms) in_bounds, (shape, _), clock_ = shared_memory.store_buffer_content( - key, write_range, val, global_core_id + key, + write_range, + val, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) clock = clock if clock is not None else clock_ @@ -728,7 +762,16 @@ def swap( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) read_write_range = interpret_utils.to_range(transforms) ret, (shape, _), clock = shared_memory.swap_buffer_content( - key, read_write_range, val, mask, global_core_id + key, + read_write_range, + val, + mask, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) if ret is None: @@ -847,7 +890,12 @@ def execute_read(self): # Signal the send semaphore. 
if self.src_sem is not None: self.src_sem.signal( - self.data_size, self.src_global_core_id, clock=self.clock + self.data_size, self.src_global_core_id, clock=self.clock, + logging_info=interpret_utils.LoggingInfo( + device_id=self.src_device_id, + local_core_id=self.src_local_core_id, + source_info=self.source_info, + ), ) self.state = DmaState.READ @@ -887,7 +935,12 @@ def execute_write(self): vc.inc_vector_clock(self.clock, self.virtual_device_id) self.dst_sem.signal( - self.data_size, self.dst_global_core_id, clock=self.clock + self.data_size, self.dst_global_core_id, clock=self.clock, + logging_info=interpret_utils.LoggingInfo( + device_id=self.dst_device_id, + local_core_id=self.dst_local_core_id, + source_info=self.source_info, + ), ) self.data = None @@ -993,7 +1046,7 @@ def dma_start( dma.execute_read_and_write() -def dma_wait(device_id, local_core_id, sem_id, size): +def dma_wait(device_id, local_core_id, sem_id, size, source_info=None): shared_memory = _get_shared_memory() device_id = int(device_id) @@ -1007,7 +1060,15 @@ def dma_wait(device_id, local_core_id, sem_id, size): [sem_id], global_core_id ) assert sem is not None - sem.wait(size, global_core_id, has_tasks=True) + sem.wait( + size, + global_core_id, + has_tasks=True, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, local_core_id=local_core_id, + source_info=source_info, + ), + ) def semaphore_signal( @@ -1017,6 +1078,7 @@ def semaphore_signal( inc, target_device_id, target_local_core_id, + source_info=None, ): shared_memory = _get_shared_memory() @@ -1042,10 +1104,15 @@ def semaphore_signal( inc, shared_memory.get_global_core_id(target_device_id, target_local_core_id), clock, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) -def semaphore_wait(device_id, local_core_id, sem_id, value): +def semaphore_wait(device_id, local_core_id, sem_id, value, source_info=None): shared_memory = 
_get_shared_memory() device_id = int(device_id) @@ -1058,7 +1125,15 @@ def semaphore_wait(device_id, local_core_id, sem_id, value): [sem_id], global_core_id ) assert sem is not None - sem.wait(value, global_core_id) + sem.wait( + value, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), + ) _SEMAPHORE = mosaic_core.MemorySpace.SEMAPHORE @@ -1271,7 +1346,9 @@ def f(*args, jaxpr): memory_space = _forward_any_to_hbm(v.aval.memory_space) allocs.append( callback.io_callback( - _allocate_buffer, + functools.partial( + _allocate_buffer, source_info=eqn.source_info + ), jax.ShapeDtypeStruct((), jnp.int16), device_id, local_core_id, @@ -1297,7 +1374,9 @@ def f(*args, jaxpr): pass else: callback.io_callback( - _deallocate_buffer, + functools.partial( + _deallocate_buffer, source_info=eqn.source_info + ), None, device_id, local_core_id, @@ -1423,7 +1502,7 @@ def f(*args, jaxpr): read_shape = src_ref_aval.shape read_dtype = src_ref_aval.dtype callback.io_callback( - dma_wait, + functools.partial(dma_wait, source_info=eqn.source_info), (), device_id, local_core_id, @@ -1449,7 +1528,7 @@ def f(*args, jaxpr): target_device_id, eqn.params['device_id_type'], axis_sizes, axis_indices) callback.io_callback( - semaphore_signal, + functools.partial(semaphore_signal, source_info=eqn.source_info), (), device_id, local_core_id, @@ -1477,12 +1556,6 @@ def f(*args, jaxpr): ) out = [] - elif prim is primitives.atomic_rmw_p: - raise NotImplementedError('atomic_rmw_p') - - elif prim is primitives.atomic_cas_p: - raise NotImplementedError('atomic_cas_p') - else: if interpret_params.skip_floating_point_ops and all( interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars @@ -1714,7 +1787,7 @@ def interpret_pallas_call( mosaic_params = mosaic_core.CompilerParams() else: assert isinstance(compiler_params, mosaic_core.CompilerParams) - mosaic_params = compiler_params # type: ignore[assignment] 
+ mosaic_params = compiler_params del compiler_params args = [remove_memory_space_p.bind(a) for a in args] @@ -1845,15 +1918,15 @@ def interpret_pallas_call( is_input = i < grid_mapping.num_inputs is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) aval = var.aval - memory_space = _forward_any_to_hbm(aval.memory_space) + memory_space = _forward_any_to_hbm(aval.memory_space) # pyrefly: ignore[missing-attribute] if memory_space is _SEMAPHORE: kernel_buffer_ids.append( callback.io_callback( _allocate_semaphores, - jax.ShapeDtypeStruct(aval.shape, jnp.int16), + jax.ShapeDtypeStruct(aval.shape, jnp.int16), # pyrefly: ignore[missing-attribute] device_id, None, # local_core_id - aval.shape, + aval.shape, # pyrefly: ignore[missing-attribute] ordered=True, ) ) @@ -1877,7 +1950,7 @@ def interpret_pallas_call( None, # local_core_id, TPU_MEMORY_SPACE_IDXS[memory_space], interpret_params.get_uninitialized_array( - var.aval.shape, var.aval.dtype + var.aval.shape, var.aval.dtype # pyrefly: ignore[missing-attribute] ), ordered=True, ) @@ -2156,7 +2229,7 @@ def _store_to_output_buffer(index, output_var, transform): assert len(next_start_indices[num_inputs + j].shape) == 1 transform = indexing.NDIndexer( indices=tuple( - indexing.ds(st, sz) if not iid else st # type: ignore[misc] + indexing.ds(st, sz) if not iid else st # pyrefly: ignore[bad-argument-type] for st, sz, iid in zip( cur_start_indices[num_inputs + j], block_shapes[num_inputs + j], diff --git a/jax/_src/pallas/mosaic/interpret/shared_memory.py b/jax/_src/pallas/mosaic/interpret/shared_memory.py index d61250fa16c2..9233f2f1a4f3 100644 --- a/jax/_src/pallas/mosaic/interpret/shared_memory.py +++ b/jax/_src/pallas/mosaic/interpret/shared_memory.py @@ -21,7 +21,9 @@ import threading from typing import Any, Callable, Literal +from absl import logging from jax._src.pallas.mosaic.interpret import vector_clock as vc +import jax._src.pallas.mosaic.interpret.utils as interpret_utils import numpy as np @@ -31,9 
+33,11 @@ def __init__( self, shared_memory: SharedMemory, semaphore_id: int, + enable_logging: bool = False, ): self.shared_memory = shared_memory self.id: int = semaphore_id + self.enable_logging: bool = enable_logging # TODO(jburnim): Use one Condition variable per device. (Which will be # easier to do when we're using single integer device IDs.) @@ -66,10 +70,28 @@ def detect_races(self) -> bool: def dma_execution_mode(self) -> str: return self.shared_memory.dma_execution_mode + def _log(self, message: str): + """Logs a message to `absl.logging`. To be called while holding the lock on `self.cv`.""" + # Log every line separately to make sure `absl.logging` adds the correct + # prefix (i.e. I***